Files
waf-platform/EdgeCommon/pkg/iplibrary/maxmind_updater.go

246 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary
import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
)
// MaxMindUpdater downloads MaxMind GeoLite2 databases (City and, optionally, ASN)
// and reloads the IP library with the new files. Create it with NewMaxMindUpdater
// and call Start to run updates in the background.
type MaxMindUpdater struct {
config *MaxMindAutoUpdateConfig // update settings; NewMaxMindUpdater fills in UpdateURL/UpdateInterval defaults
httpClient *http.Client // client used for database downloads (NewMaxMindUpdater sets a 30-minute timeout)
}
// MaxMindAutoUpdateConfig configures automatic MaxMind database updates.
// The yaml/json struct tags define the keys used when it is (un)marshaled from configuration.
type MaxMindAutoUpdateConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"` // whether auto-update is enabled; Start does nothing when false
LicenseKey string `yaml:"licenseKey" json:"licenseKey"` // MaxMind license key; Start does nothing when it is empty
UpdateURL string `yaml:"updateURL" json:"updateURL"` // download endpoint (NewMaxMindUpdater defaults it to the official MaxMind URL)
UpdateInterval string `yaml:"updateInterval" json:"updateInterval"` // time between updates, e.g. "7d" or "24h" (NewMaxMindUpdater defaults it to "7d")
CityDBPath string `yaml:"cityDBPath" json:"cityDBPath"` // destination path of the GeoLite2-City database
ASNDBPath string `yaml:"asnDBPath" json:"asnDBPath"` // destination path of the GeoLite2-ASN database (optional)
}
// NewMaxMindUpdater returns an updater driven by config.
//
// A nil config is replaced by an empty one. Missing settings are filled in on
// the config itself, so the caller's struct is modified:
//   - UpdateURL falls back to MaxMind's official download endpoint.
//   - UpdateInterval falls back to "7d".
//
// Downloads use an HTTP client with a 30-minute overall timeout.
func NewMaxMindUpdater(config *MaxMindAutoUpdateConfig) *MaxMindUpdater {
	cfg := config
	if cfg == nil {
		cfg = new(MaxMindAutoUpdateConfig)
	}

	if cfg.UpdateURL == "" {
		cfg.UpdateURL = "https://download.maxmind.com/app/geoip_download"
	}
	if cfg.UpdateInterval == "" {
		cfg.UpdateInterval = "7d" // one week
	}

	client := &http.Client{Timeout: 30 * time.Minute}
	return &MaxMindUpdater{httpClient: client, config: cfg}
}
// Start launches the background auto-update loop and returns immediately.
//
// It does nothing in any of these cases:
//   - auto-update is disabled;
//   - no license key is configured;
//   - UpdateInterval does not parse to a positive duration. A zero or negative
//     interval would otherwise make time.NewTicker panic.
//
// Otherwise a single goroutine runs one update right away, then one more every
// interval. Running everything in that one goroutine keeps updates strictly
// sequential. Two slow downloads therefore never overlap, and never write the
// same temporary file at the same time.
//
// NOTE(review): the loop cannot be stopped and lives for the rest of the process.
// Calling Start twice starts two loops. A stop channel would need a new field on
// MaxMindUpdater.
func (this *MaxMindUpdater) Start() {
	if !this.config.Enabled {
		return
	}
	if len(this.config.LicenseKey) == 0 {
		return // updates are impossible without a license key
	}

	interval, err := this.parseInterval(this.config.UpdateInterval)
	if err != nil || interval <= 0 {
		// TODO: surface the configuration error once a logger is available here.
		return
	}

	go func() {
		// First update immediately, then periodically.
		this.updateOnce()

		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for range ticker.C {
			this.updateOnce()
		}
	}()
}
// updateOnce downloads every configured database and reloads the IP library
// exactly once if at least one download succeeded.
//
// The City and ASN downloads are attempted independently, so a failure of one
// does not prevent the other. When nothing new was written, no reload happens.
// Download errors are currently dropped (best effort); the next scheduled run
// retries.
func (this *MaxMindUpdater) updateOnce() {
	updated := false

	// City database
	if len(this.config.CityDBPath) > 0 {
		if err := this.downloadDatabase("GeoLite2-City", this.config.CityDBPath); err == nil {
			updated = true
		}
		// TODO: log the download error once a logger is available here.
	}

	// ASN database (optional)
	if len(this.config.ASNDBPath) > 0 {
		if err := this.downloadDatabase("GeoLite2-ASN", this.config.ASNDBPath); err == nil {
			updated = true
		}
	}

	if updated {
		this.reloadLibrary()
	}
}
// downloadDatabase fetches the tar.gz archive of one MaxMind edition
// (dbType, e.g. "GeoLite2-City") and extracts its .mmdb file to targetPath.
//
// The edition_id, license_key and suffix parameters are added to
// config.UpdateURL with proper URL encoding. A license key containing reserved
// characters is sent intact, and query parameters already present in UpdateURL
// are preserved.
//
// Returns an error if:
//   - the URL is invalid;
//   - the request fails;
//   - the response status is not 200;
//   - extraction fails.
//
// Transport errors are unwrapped so the license key never appears in the
// returned error text.
func (this *MaxMindUpdater) downloadDatabase(dbType, targetPath string) error {
	req, err := http.NewRequest(http.MethodGet, this.config.UpdateURL, nil)
	if err != nil {
		return fmt.Errorf("build download request failed: %w", err)
	}
	query := req.URL.Query()
	query.Set("edition_id", dbType)
	query.Set("license_key", this.config.LicenseKey)
	query.Set("suffix", "tar.gz")
	req.URL.RawQuery = query.Encode()

	resp, err := this.httpClient.Do(req)
	if err != nil {
		// http.Client returns a *url.Error whose message contains the full request
		// URL, including license_key. Keep only the underlying cause.
		if wrapped, ok := err.(interface{ Unwrap() error }); ok {
			if inner := wrapped.Unwrap(); inner != nil {
				err = inner
			}
		}
		return fmt.Errorf("download failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("download failed: status code %d", resp.StatusCode)
	}

	// Decompress and extract the .mmdb file.
	return this.extractMMDB(resp.Body, targetPath)
}
// extractMMDB reads a gzip-compressed tar stream and writes the first regular
// file with a ".mmdb" extension to targetPath.
//
// The data is written to targetPath+".tmp" in the same directory, flushed to
// disk, and then renamed over targetPath. Readers therefore never see a
// half-written database. On any failure the temporary file is removed and
// targetPath is left untouched.
//
// Non-regular entries (directories, symlinks, ...) are skipped, even when their
// name ends in ".mmdb". They carry no file data and would otherwise replace the
// database with an empty file.
//
// Parameters:
//   - reader: the raw .tar.gz byte stream (for example an HTTP response body).
//   - targetPath: destination of the extracted database; missing parent
//     directories are created with mode 0755.
func (this *MaxMindUpdater) extractMMDB(reader io.Reader, targetPath string) error {
	gzReader, err := gzip.NewReader(reader)
	if err != nil {
		return fmt.Errorf("create gzip reader failed: %w", err)
	}
	defer func() { _ = gzReader.Close() }()
	tarReader := tar.NewReader(gzReader)

	// Make sure the target directory exists.
	if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
		return fmt.Errorf("create target directory failed: %w", err)
	}

	// Write to a sibling temp file so the final rename stays on one filesystem.
	tmpFile := targetPath + ".tmp"
	outFile, err := os.Create(tmpFile)
	if err != nil {
		return fmt.Errorf("create temp file failed: %w", err)
	}
	committed := false
	defer func() {
		if !committed {
			// Failed update: release the handle (it may already be closed) and
			// delete the partial file.
			_ = outFile.Close()
			_ = os.Remove(tmpFile)
		}
	}()

	// Find the first regular .mmdb entry.
	found := false
	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return fmt.Errorf("read tar failed: %w", err)
		}
		if header.Typeflag != tar.TypeReg || filepath.Ext(header.Name) != ".mmdb" {
			continue
		}
		if _, err := io.Copy(outFile, tarReader); err != nil {
			return fmt.Errorf("copy file failed: %w", err)
		}
		found = true
		break
	}
	if !found {
		return fmt.Errorf(".mmdb file not found in archive")
	}

	// Flush to disk before the rename so a crash cannot leave a truncated database.
	if err := outFile.Sync(); err != nil {
		return fmt.Errorf("sync temp file failed: %w", err)
	}
	if err := outFile.Close(); err != nil {
		return fmt.Errorf("close temp file failed: %w", err)
	}

	// Atomically replace the old database.
	if err := os.Rename(tmpFile, targetPath); err != nil {
		return fmt.Errorf("rename temp file failed: %w", err)
	}
	committed = true
	return nil
}
// reloadLibrary replaces the process-wide IP library with one backed by a freshly
// opened MaxMind reader for config.CityDBPath and config.ASNDBPath.
//
// It acts only when the current defaultLibrary has a *MaxMindReader as its
// reader; any other library is left alone. The whole check-and-swap runs while
// libraryLocker is held.
//
// If the new reader cannot be created, the existing library stays in place and
// the error is dropped. On success the old library is destroyed, and both
// defaultLibrary and commonLibrary point at the new one.
func (this *MaxMindUpdater) reloadLibrary() {
	libraryLocker.Lock()
	defer libraryLocker.Unlock()

	if defaultLibrary == nil || defaultLibrary.reader == nil {
		return
	}
	if _, isMaxMind := defaultLibrary.reader.(*MaxMindReader); !isMaxMind {
		return // another provider is active; nothing to reload
	}

	freshReader, err := NewMaxMindReader(this.config.CityDBPath, this.config.ASNDBPath)
	if err != nil {
		return // keep serving from the existing library
	}

	defaultLibrary.Destroy()
	defaultLibrary = NewIPLibraryWithReader(freshReader)
	commonLibrary = defaultLibrary
}
// parseInterval converts an update-interval string into a time.Duration.
//
// Accepted forms:
//   - "" (or only whitespace): the default of 7 days;
//   - anything time.ParseDuration understands, e.g. "24h", "90m", "1h30m";
//   - a whole number of days, optionally followed by a ParseDuration suffix,
//     e.g. "7d" or "7d12h".
//
// Surrounding whitespace is ignored; trailing garbage such as "7dx" is rejected.
// The result must be strictly positive, because time.NewTicker panics on a
// non-positive interval. It must also fit in a time.Duration. Otherwise an
// error is returned.
func (this *MaxMindUpdater) parseInterval(intervalStr string) (time.Duration, error) {
	const (
		maxDuration = time.Duration(1<<63 - 1)
		maxDays     = int(maxDuration / (24 * time.Hour))
	)

	isSpace := func(c byte) bool { return c == ' ' || c == '\t' || c == '\n' || c == '\r' }
	s := intervalStr
	for len(s) > 0 && isSpace(s[0]) {
		s = s[1:]
	}
	for len(s) > 0 && isSpace(s[len(s)-1]) {
		s = s[:len(s)-1]
	}
	if len(s) == 0 {
		return 7 * 24 * time.Hour, nil // default: 7 days
	}

	// time.ParseDuration has no day unit, so a leading "<digits>d" is handled here.
	var total time.Duration
	n := 0
	for n < len(s) && s[n] >= '0' && s[n] <= '9' {
		n++
	}
	if n > 0 && n < len(s) && s[n] == 'd' {
		days := 0
		for i := 0; i < n; i++ {
			days = days*10 + int(s[i]-'0')
			if days > maxDays {
				return 0, fmt.Errorf("invalid interval format: %s", intervalStr)
			}
		}
		total = time.Duration(days) * 24 * time.Hour
		s = s[n+1:]
	}

	// The remainder (or the whole string when there is no day part) must be a
	// Go duration, and adding it must not overflow.
	if len(s) > 0 {
		d, err := time.ParseDuration(s)
		if err != nil || d > maxDuration-total {
			return 0, fmt.Errorf("invalid interval format: %s", intervalStr)
		}
		total += d
	}

	if total <= 0 {
		return 0, fmt.Errorf("interval must be positive: %s", intervalStr)
	}
	return total, nil
}