Initial commit (code only without large binaries)
This commit is contained in:
245
EdgeCommon/pkg/iplibrary/maxmind_updater.go
Normal file
245
EdgeCommon/pkg/iplibrary/maxmind_updater.go
Normal file
@@ -0,0 +1,245 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MaxMindUpdater MaxMind 数据库更新器
|
||||
type MaxMindUpdater struct {
|
||||
config *MaxMindAutoUpdateConfig
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// MaxMindAutoUpdateConfig MaxMind 自动更新配置
|
||||
type MaxMindAutoUpdateConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用自动更新
|
||||
LicenseKey string `yaml:"licenseKey" json:"licenseKey"` // MaxMind 许可证密钥
|
||||
UpdateURL string `yaml:"updateURL" json:"updateURL"` // 更新 URL(默认使用 MaxMind 官方)
|
||||
UpdateInterval string `yaml:"updateInterval" json:"updateInterval"` // 更新间隔(如 "7d", "24h")
|
||||
CityDBPath string `yaml:"cityDBPath" json:"cityDBPath"` // City 数据库路径
|
||||
ASNDBPath string `yaml:"asnDBPath" json:"asnDBPath"` // ASN 数据库路径(可选)
|
||||
}
|
||||
|
||||
// NewMaxMindUpdater 创建更新器
|
||||
func NewMaxMindUpdater(config *MaxMindAutoUpdateConfig) *MaxMindUpdater {
|
||||
if config == nil {
|
||||
config = &MaxMindAutoUpdateConfig{}
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if len(config.UpdateURL) == 0 {
|
||||
config.UpdateURL = "https://download.maxmind.com/app/geoip_download"
|
||||
}
|
||||
if len(config.UpdateInterval) == 0 {
|
||||
config.UpdateInterval = "7d" // 默认 7 天
|
||||
}
|
||||
|
||||
return &MaxMindUpdater{
|
||||
config: config,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Minute,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动自动更新(定时任务)
|
||||
func (this *MaxMindUpdater) Start() {
|
||||
if !this.config.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
if len(this.config.LicenseKey) == 0 {
|
||||
return // 没有许可证密钥,无法更新
|
||||
}
|
||||
|
||||
// 解析更新间隔
|
||||
interval, err := this.parseInterval(this.config.UpdateInterval)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 立即执行一次更新检查
|
||||
go this.updateOnce()
|
||||
|
||||
// 定时更新
|
||||
ticker := time.NewTicker(interval)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
this.updateOnce()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// updateOnce 执行一次更新检查
|
||||
func (this *MaxMindUpdater) updateOnce() {
|
||||
// 更新 City 数据库
|
||||
if len(this.config.CityDBPath) > 0 {
|
||||
err := this.downloadDatabase("GeoLite2-City", this.config.CityDBPath)
|
||||
if err != nil {
|
||||
// 记录错误但不中断
|
||||
return
|
||||
}
|
||||
// 重新加载 IP 库
|
||||
this.reloadLibrary()
|
||||
}
|
||||
|
||||
// 更新 ASN 数据库(如果配置了)
|
||||
if len(this.config.ASNDBPath) > 0 {
|
||||
_ = this.downloadDatabase("GeoLite2-ASN", this.config.ASNDBPath)
|
||||
// 重新加载 IP 库
|
||||
this.reloadLibrary()
|
||||
}
|
||||
}
|
||||
|
||||
// downloadDatabase 下载数据库
|
||||
func (this *MaxMindUpdater) downloadDatabase(dbType, targetPath string) error {
|
||||
// 构建下载 URL
|
||||
url := fmt.Sprintf("%s?edition_id=%s&license_key=%s&suffix=tar.gz",
|
||||
this.config.UpdateURL,
|
||||
dbType,
|
||||
this.config.LicenseKey,
|
||||
)
|
||||
|
||||
// 下载文件
|
||||
resp, err := this.httpClient.Get(url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download failed: status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 解压并提取 .mmdb 文件
|
||||
return this.extractMMDB(resp.Body, targetPath)
|
||||
}
|
||||
|
||||
// extractMMDB 从 tar.gz 中提取 .mmdb 文件
|
||||
func (this *MaxMindUpdater) extractMMDB(reader io.Reader, targetPath string) error {
|
||||
gzReader, err := gzip.NewReader(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create gzip reader failed: %w", err)
|
||||
}
|
||||
defer gzReader.Close()
|
||||
|
||||
tarReader := tar.NewReader(gzReader)
|
||||
|
||||
// 确保目标目录存在
|
||||
targetDir := filepath.Dir(targetPath)
|
||||
if err := os.MkdirAll(targetDir, 0755); err != nil {
|
||||
return fmt.Errorf("create target directory failed: %w", err)
|
||||
}
|
||||
|
||||
// 创建临时文件
|
||||
tmpFile := targetPath + ".tmp"
|
||||
outFile, err := os.Create(tmpFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp file failed: %w", err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
// 查找 .mmdb 文件
|
||||
found := false
|
||||
for {
|
||||
header, err := tarReader.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("read tar failed: %w", err)
|
||||
}
|
||||
|
||||
if filepath.Ext(header.Name) == ".mmdb" {
|
||||
// 复制文件
|
||||
_, err = io.Copy(outFile, tarReader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("copy file failed: %w", err)
|
||||
}
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return fmt.Errorf(".mmdb file not found in archive")
|
||||
}
|
||||
|
||||
// 关闭临时文件
|
||||
if err := outFile.Close(); err != nil {
|
||||
return fmt.Errorf("close temp file failed: %w", err)
|
||||
}
|
||||
|
||||
// 原子替换
|
||||
if err := os.Rename(tmpFile, targetPath); err != nil {
|
||||
return fmt.Errorf("rename temp file failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reloadLibrary 重新加载 IP 库
|
||||
func (this *MaxMindUpdater) reloadLibrary() {
|
||||
// 如果当前使用的是 MaxMind,重新加载
|
||||
libraryLocker.Lock()
|
||||
defer libraryLocker.Unlock()
|
||||
|
||||
// 检查当前库是否是 MaxMind
|
||||
if defaultLibrary != nil && defaultLibrary.reader != nil {
|
||||
if _, ok := defaultLibrary.reader.(*MaxMindReader); ok {
|
||||
// 重新创建 MaxMind Reader
|
||||
reader, err := NewMaxMindReader(this.config.CityDBPath, this.config.ASNDBPath)
|
||||
if err == nil {
|
||||
// 销毁旧的
|
||||
defaultLibrary.Destroy()
|
||||
// 创建新的
|
||||
defaultLibrary = NewIPLibraryWithReader(reader)
|
||||
commonLibrary = defaultLibrary
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseInterval 解析时间间隔字符串(如 "7d", "24h", "1h30m")
|
||||
func (this *MaxMindUpdater) parseInterval(intervalStr string) (time.Duration, error) {
|
||||
// 支持常见格式
|
||||
if len(intervalStr) == 0 {
|
||||
return 7 * 24 * time.Hour, nil // 默认 7 天
|
||||
}
|
||||
|
||||
// 尝试直接解析
|
||||
duration, err := time.ParseDuration(intervalStr)
|
||||
if err == nil {
|
||||
return duration, nil
|
||||
}
|
||||
|
||||
// 尝试解析 "7d" 格式
|
||||
var days int
|
||||
var hours int
|
||||
var minutes int
|
||||
_, err = fmt.Sscanf(intervalStr, "%dd", &days)
|
||||
if err == nil {
|
||||
return time.Duration(days) * 24 * time.Hour, nil
|
||||
}
|
||||
|
||||
_, err = fmt.Sscanf(intervalStr, "%dh", &hours)
|
||||
if err == nil {
|
||||
return time.Duration(hours) * time.Hour, nil
|
||||
}
|
||||
|
||||
_, err = fmt.Sscanf(intervalStr, "%dm", &minutes)
|
||||
if err == nil {
|
||||
return time.Duration(minutes) * time.Minute, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("invalid interval format: %s", intervalStr)
|
||||
}
|
||||
Reference in New Issue
Block a user