246 lines
6.1 KiB
Go
246 lines
6.1 KiB
Go
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||
|
||
package iplibrary
|
||
|
||
import (
|
||
"archive/tar"
|
||
"compress/gzip"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"os"
|
||
"path/filepath"
|
||
"time"
|
||
)
|
||
|
||
// MaxMindUpdater downloads MaxMind GeoLite2 databases on a schedule (see Start)
// and, when the current IP library is MaxMind-backed, reloads it from the
// freshly downloaded files.
type MaxMindUpdater struct {
	config *MaxMindAutoUpdateConfig // update settings; NewMaxMindUpdater fills in missing defaults

	httpClient *http.Client // client used for all database downloads
}
|
||
|
||
// MaxMindAutoUpdateConfig holds the settings for automatic MaxMind database updates.
type MaxMindAutoUpdateConfig struct {
	Enabled        bool   `yaml:"enabled" json:"enabled"`               // whether automatic updates are enabled
	LicenseKey     string `yaml:"licenseKey" json:"licenseKey"`         // MaxMind license key; Start does nothing when it is empty
	UpdateURL      string `yaml:"updateURL" json:"updateURL"`           // download endpoint; NewMaxMindUpdater defaults it to the official MaxMind URL
	UpdateInterval string `yaml:"updateInterval" json:"updateInterval"` // update interval such as "7d" or "24h"; NewMaxMindUpdater defaults it to "7d"
	CityDBPath     string `yaml:"cityDBPath" json:"cityDBPath"`         // destination path of the City database (skipped when empty)
	ASNDBPath      string `yaml:"asnDBPath" json:"asnDBPath"`           // destination path of the ASN database (optional, skipped when empty)
}
|
||
|
||
// NewMaxMindUpdater 创建更新器
|
||
func NewMaxMindUpdater(config *MaxMindAutoUpdateConfig) *MaxMindUpdater {
|
||
if config == nil {
|
||
config = &MaxMindAutoUpdateConfig{}
|
||
}
|
||
|
||
// 设置默认值
|
||
if len(config.UpdateURL) == 0 {
|
||
config.UpdateURL = "https://download.maxmind.com/app/geoip_download"
|
||
}
|
||
if len(config.UpdateInterval) == 0 {
|
||
config.UpdateInterval = "7d" // 默认 7 天
|
||
}
|
||
|
||
return &MaxMindUpdater{
|
||
config: config,
|
||
httpClient: &http.Client{
|
||
Timeout: 30 * time.Minute,
|
||
},
|
||
}
|
||
}
|
||
|
||
// Start 启动自动更新(定时任务)
|
||
func (this *MaxMindUpdater) Start() {
|
||
if !this.config.Enabled {
|
||
return
|
||
}
|
||
|
||
if len(this.config.LicenseKey) == 0 {
|
||
return // 没有许可证密钥,无法更新
|
||
}
|
||
|
||
// 解析更新间隔
|
||
interval, err := this.parseInterval(this.config.UpdateInterval)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 立即执行一次更新检查
|
||
go this.updateOnce()
|
||
|
||
// 定时更新
|
||
ticker := time.NewTicker(interval)
|
||
go func() {
|
||
for range ticker.C {
|
||
this.updateOnce()
|
||
}
|
||
}()
|
||
}
|
||
|
||
// updateOnce 执行一次更新检查
|
||
func (this *MaxMindUpdater) updateOnce() {
|
||
// 更新 City 数据库
|
||
if len(this.config.CityDBPath) > 0 {
|
||
err := this.downloadDatabase("GeoLite2-City", this.config.CityDBPath)
|
||
if err != nil {
|
||
// 记录错误但不中断
|
||
return
|
||
}
|
||
// 重新加载 IP 库
|
||
this.reloadLibrary()
|
||
}
|
||
|
||
// 更新 ASN 数据库(如果配置了)
|
||
if len(this.config.ASNDBPath) > 0 {
|
||
_ = this.downloadDatabase("GeoLite2-ASN", this.config.ASNDBPath)
|
||
// 重新加载 IP 库
|
||
this.reloadLibrary()
|
||
}
|
||
}
|
||
|
||
// downloadDatabase 下载数据库
|
||
func (this *MaxMindUpdater) downloadDatabase(dbType, targetPath string) error {
|
||
// 构建下载 URL
|
||
url := fmt.Sprintf("%s?edition_id=%s&license_key=%s&suffix=tar.gz",
|
||
this.config.UpdateURL,
|
||
dbType,
|
||
this.config.LicenseKey,
|
||
)
|
||
|
||
// 下载文件
|
||
resp, err := this.httpClient.Get(url)
|
||
if err != nil {
|
||
return fmt.Errorf("download failed: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return fmt.Errorf("download failed: status code %d", resp.StatusCode)
|
||
}
|
||
|
||
// 解压并提取 .mmdb 文件
|
||
return this.extractMMDB(resp.Body, targetPath)
|
||
}
|
||
|
||
// extractMMDB 从 tar.gz 中提取 .mmdb 文件
|
||
func (this *MaxMindUpdater) extractMMDB(reader io.Reader, targetPath string) error {
|
||
gzReader, err := gzip.NewReader(reader)
|
||
if err != nil {
|
||
return fmt.Errorf("create gzip reader failed: %w", err)
|
||
}
|
||
defer gzReader.Close()
|
||
|
||
tarReader := tar.NewReader(gzReader)
|
||
|
||
// 确保目标目录存在
|
||
targetDir := filepath.Dir(targetPath)
|
||
if err := os.MkdirAll(targetDir, 0755); err != nil {
|
||
return fmt.Errorf("create target directory failed: %w", err)
|
||
}
|
||
|
||
// 创建临时文件
|
||
tmpFile := targetPath + ".tmp"
|
||
outFile, err := os.Create(tmpFile)
|
||
if err != nil {
|
||
return fmt.Errorf("create temp file failed: %w", err)
|
||
}
|
||
defer outFile.Close()
|
||
|
||
// 查找 .mmdb 文件
|
||
found := false
|
||
for {
|
||
header, err := tarReader.Next()
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
if err != nil {
|
||
return fmt.Errorf("read tar failed: %w", err)
|
||
}
|
||
|
||
if filepath.Ext(header.Name) == ".mmdb" {
|
||
// 复制文件
|
||
_, err = io.Copy(outFile, tarReader)
|
||
if err != nil {
|
||
return fmt.Errorf("copy file failed: %w", err)
|
||
}
|
||
found = true
|
||
break
|
||
}
|
||
}
|
||
|
||
if !found {
|
||
return fmt.Errorf(".mmdb file not found in archive")
|
||
}
|
||
|
||
// 关闭临时文件
|
||
if err := outFile.Close(); err != nil {
|
||
return fmt.Errorf("close temp file failed: %w", err)
|
||
}
|
||
|
||
// 原子替换
|
||
if err := os.Rename(tmpFile, targetPath); err != nil {
|
||
return fmt.Errorf("rename temp file failed: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// reloadLibrary 重新加载 IP 库
|
||
func (this *MaxMindUpdater) reloadLibrary() {
|
||
// 如果当前使用的是 MaxMind,重新加载
|
||
libraryLocker.Lock()
|
||
defer libraryLocker.Unlock()
|
||
|
||
// 检查当前库是否是 MaxMind
|
||
if defaultLibrary != nil && defaultLibrary.reader != nil {
|
||
if _, ok := defaultLibrary.reader.(*MaxMindReader); ok {
|
||
// 重新创建 MaxMind Reader
|
||
reader, err := NewMaxMindReader(this.config.CityDBPath, this.config.ASNDBPath)
|
||
if err == nil {
|
||
// 销毁旧的
|
||
defaultLibrary.Destroy()
|
||
// 创建新的
|
||
defaultLibrary = NewIPLibraryWithReader(reader)
|
||
commonLibrary = defaultLibrary
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// parseInterval 解析时间间隔字符串(如 "7d", "24h", "1h30m")
|
||
func (this *MaxMindUpdater) parseInterval(intervalStr string) (time.Duration, error) {
|
||
// 支持常见格式
|
||
if len(intervalStr) == 0 {
|
||
return 7 * 24 * time.Hour, nil // 默认 7 天
|
||
}
|
||
|
||
// 尝试直接解析
|
||
duration, err := time.ParseDuration(intervalStr)
|
||
if err == nil {
|
||
return duration, nil
|
||
}
|
||
|
||
// 尝试解析 "7d" 格式
|
||
var days int
|
||
var hours int
|
||
var minutes int
|
||
_, err = fmt.Sscanf(intervalStr, "%dd", &days)
|
||
if err == nil {
|
||
return time.Duration(days) * 24 * time.Hour, nil
|
||
}
|
||
|
||
_, err = fmt.Sscanf(intervalStr, "%dh", &hours)
|
||
if err == nil {
|
||
return time.Duration(hours) * time.Hour, nil
|
||
}
|
||
|
||
_, err = fmt.Sscanf(intervalStr, "%dm", &minutes)
|
||
if err == nil {
|
||
return time.Duration(minutes) * time.Minute, nil
|
||
}
|
||
|
||
return 0, fmt.Errorf("invalid interval format: %s", intervalStr)
|
||
}
|