// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||
//go:build plus
|
||
|
||
package nodes
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"errors"
|
||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||
"github.com/TeaOSLab/EdgeNode/internal/http3"
|
||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||
"github.com/iwind/TeaGo/types"
|
||
"net"
|
||
"net/http"
|
||
"regexp"
|
||
"sync"
|
||
"sync/atomic"
|
||
)
|
||
|
||
// sharedHTTP3Manager is the package-level HTTP3Manager instance. The reload
// handler registered in init() pushes policy and listener updates into it.
var sharedHTTP3Manager = NewHTTP3Manager()
|
||
|
||
func init() {
|
||
if !teaconst.IsMain {
|
||
return
|
||
}
|
||
|
||
var listener = &HTTPListener{
|
||
isHTTPS: true,
|
||
isHTTP3: true,
|
||
}
|
||
events.On(events.EventLoaded, func() {
|
||
sharedListenerManager.http3Listener = listener // 注册到ListenerManager,以便统计用
|
||
})
|
||
|
||
var eventLocker = sync.Mutex{}
|
||
events.OnEvents([]events.Event{events.EventReload, events.EventReloadSomeServers}, func() {
|
||
go func() {
|
||
eventLocker.Lock()
|
||
defer eventLocker.Unlock()
|
||
|
||
if sharedNodeConfig == nil {
|
||
return
|
||
}
|
||
|
||
_ = sharedHTTP3Manager.Update(sharedNodeConfig.HTTP3Policies)
|
||
sharedHTTP3Manager.UpdateHTTPListener(listener)
|
||
|
||
listener.Reload(sharedNodeConfig.HTTP3Group())
|
||
}()
|
||
})
|
||
}
|
||
|
||
// HTTP3Manager HTTP3管理器
|
||
type HTTP3Manager struct {
|
||
locker sync.RWMutex
|
||
|
||
hasHTTP3 bool
|
||
|
||
policies map[int64]*nodeconfigs.HTTP3Policy // clusterId => *HTTP3Policy
|
||
serverMap map[int]*http3.Server // port => *Server
|
||
mobileUserAgentReg *regexp.Regexp
|
||
|
||
httpListener *HTTPListener
|
||
tlsConfig *tls.Config
|
||
}
|
||
|
||
func NewHTTP3Manager() *HTTP3Manager {
|
||
return &HTTP3Manager{
|
||
policies: map[int64]*nodeconfigs.HTTP3Policy{},
|
||
serverMap: map[int]*http3.Server{},
|
||
mobileUserAgentReg: regexp.MustCompile(`(?i)(iPhone|Android)`),
|
||
}
|
||
}
|
||
|
||
// Update 更新配置
|
||
// m: clusterId => *HTTP3Policy
|
||
func (this *HTTP3Manager) Update(m map[int64]*nodeconfigs.HTTP3Policy) error {
|
||
this.locker.Lock()
|
||
defer this.locker.Unlock()
|
||
|
||
// 启动新的
|
||
var newPolicyMap = map[int64]*nodeconfigs.HTTP3Policy{} // clusterId => *HTTP3Policy
|
||
var newPorts = map[int]bool{} // port => bool
|
||
for clusterId, policy := range m {
|
||
if policy.IsOn && policy.Port > 0 {
|
||
this.policies[clusterId] = policy
|
||
newPolicyMap[clusterId] = policy
|
||
|
||
var port = policy.Port
|
||
newPorts[port] = true
|
||
|
||
_, existPort := this.serverMap[port]
|
||
if !existPort {
|
||
server, err := this.createServer(port)
|
||
if err != nil {
|
||
remotelogs.Error("HTTP3_MANAGER", "start port '"+types.String(port)+"' failed: "+err.Error())
|
||
continue
|
||
}
|
||
this.serverMap[port] = server
|
||
remotelogs.Debug("HTTP3_MANAGER", "start port '"+types.String(port)+"'")
|
||
}
|
||
}
|
||
}
|
||
this.policies = newPolicyMap
|
||
|
||
// 关闭老的
|
||
for port, server := range this.serverMap {
|
||
if !newPorts[port] {
|
||
_ = server.Close()
|
||
delete(this.serverMap, port)
|
||
remotelogs.Debug("HTTP3_MANAGER", "close port '"+types.String(port)+"'")
|
||
}
|
||
}
|
||
|
||
this.hasHTTP3 = len(this.serverMap) > 0
|
||
|
||
return nil
|
||
}
|
||
|
||
// UpdateHTTPListener 更新Listener
|
||
// 这里的Listener只是为了方便复用HTTPListener的相关方法
|
||
func (this *HTTP3Manager) UpdateHTTPListener(listener *HTTPListener) {
|
||
this.locker.Lock()
|
||
this.httpListener = listener
|
||
if listener != nil {
|
||
this.tlsConfig = listener.buildTLSConfig()
|
||
}
|
||
this.locker.Unlock()
|
||
}
|
||
|
||
// ProcessHTTP3Headers 处理HTTP3相关Headers
|
||
func (this *HTTP3Manager) ProcessHTTP3Headers(userAgent string, headers http.Header, clusterId int64) {
|
||
// 这里不要加锁,以便于提升性能
|
||
if !this.hasHTTP3 {
|
||
return
|
||
}
|
||
|
||
this.locker.RLock()
|
||
defer this.locker.RUnlock()
|
||
|
||
// 再次准确检查
|
||
if !this.hasHTTP3 {
|
||
return
|
||
}
|
||
|
||
policy, ok := this.policies[clusterId]
|
||
if !ok {
|
||
return
|
||
}
|
||
if policy.IsOn && policy.Port > 0 && (policy.SupportMobileBrowsers || !this.mobileUserAgentReg.MatchString(userAgent)) {
|
||
// TODO 版本好和有效期可以在策略里设置
|
||
headers.Set("Alt-Svc", `h3=":`+types.String(policy.Port)+`"; ma=2592000,h3-29=":`+types.String(policy.Port)+`"; ma=2592000`)
|
||
}
|
||
}
|
||
|
||
// 创建server
|
||
func (this *HTTP3Manager) createServer(port int) (*http3.Server, error) {
|
||
var addr = ":" + types.String(port)
|
||
listener, err := ListenHTTP3(addr, &tls.Config{
|
||
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||
this.locker.RLock()
|
||
var tlsConfig = this.tlsConfig
|
||
this.locker.RUnlock()
|
||
|
||
if tlsConfig != nil && tlsConfig.GetConfigForClient != nil {
|
||
return tlsConfig.GetConfigForClient(info)
|
||
}
|
||
|
||
return nil, errors.New("http3: no tls config")
|
||
},
|
||
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
|
||
this.locker.RLock()
|
||
var tlsConfig = this.tlsConfig
|
||
this.locker.RUnlock()
|
||
|
||
if tlsConfig != nil && tlsConfig.GetCertificate != nil {
|
||
return tlsConfig.GetCertificate(clientInfo)
|
||
}
|
||
|
||
return nil, errors.New("http3: no tls config")
|
||
},
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var server = &http3.Server{
|
||
Addr: ":" + types.String(port),
|
||
Handler: http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
|
||
if this.httpListener != nil {
|
||
var servePortString = "443"
|
||
if len(req.Host) > 0 {
|
||
_, hostPortString, hostErr := net.SplitHostPort(req.Host)
|
||
if hostErr == nil && len(hostPortString) > 0 {
|
||
servePortString = hostPortString
|
||
}
|
||
}
|
||
this.httpListener.ServeHTTPWithAddr(writer, req, ":"+servePortString)
|
||
}
|
||
}),
|
||
ConnState: func(conn net.Conn, state http.ConnState) {
|
||
if this.httpListener == nil {
|
||
return
|
||
}
|
||
switch state {
|
||
case http.StateNew:
|
||
atomic.AddInt64(&this.httpListener.countActiveConnections, 1)
|
||
case http.StateClosed:
|
||
atomic.AddInt64(&this.httpListener.countActiveConnections, -1)
|
||
default:
|
||
// do nothing
|
||
}
|
||
},
|
||
ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
|
||
return context.WithValue(ctx, HTTPConnContextKey, conn)
|
||
},
|
||
}
|
||
go func() {
|
||
err = server.Serve(listener)
|
||
if err != nil {
|
||
remotelogs.Error("HTTP3_MANAGER", "serve '"+addr+"' failed: "+err.Error())
|
||
this.locker.Lock()
|
||
delete(this.serverMap, port)
|
||
this.locker.Unlock()
|
||
}
|
||
}()
|
||
return server, nil
|
||
}
|