package rpc import ( "context" "crypto/tls" "encoding/base64" "errors" "fmt" "net/url" "sync" "time" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeHttpDNS/internal/configs" teaconst "github.com/TeaOSLab/EdgeHttpDNS/internal/const" "github.com/TeaOSLab/EdgeHttpDNS/internal/encrypt" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/rands" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/encoding/gzip" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "sync/atomic" ) type RPCClient struct { apiConfig *configs.APIConfig conns []*grpc.ClientConn locker sync.RWMutex NodeTaskRPC pb.NodeTaskServiceClient NodeValueRPC pb.NodeValueServiceClient HTTPDNSNodeRPC pb.HTTPDNSNodeServiceClient HTTPDNSClusterRPC pb.HTTPDNSClusterServiceClient HTTPDNSAppRPC pb.HTTPDNSAppServiceClient HTTPDNSDomainRPC pb.HTTPDNSDomainServiceClient HTTPDNSRuleRPC pb.HTTPDNSRuleServiceClient HTTPDNSRuntimeLogRPC pb.HTTPDNSRuntimeLogServiceClient HTTPDNSAccessLogRPC pb.HTTPDNSAccessLogServiceClient HTTPDNSSandboxRPC pb.HTTPDNSSandboxServiceClient totalRequests int64 failedRequests int64 totalCostMs int64 } func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) { if apiConfig == nil { return nil, errors.New("api config should not be nil") } client := &RPCClient{apiConfig: apiConfig} client.NodeTaskRPC = pb.NewNodeTaskServiceClient(client) client.NodeValueRPC = pb.NewNodeValueServiceClient(client) client.HTTPDNSNodeRPC = pb.NewHTTPDNSNodeServiceClient(client) client.HTTPDNSClusterRPC = pb.NewHTTPDNSClusterServiceClient(client) client.HTTPDNSAppRPC = pb.NewHTTPDNSAppServiceClient(client) client.HTTPDNSDomainRPC = pb.NewHTTPDNSDomainServiceClient(client) client.HTTPDNSRuleRPC = pb.NewHTTPDNSRuleServiceClient(client) client.HTTPDNSRuntimeLogRPC = pb.NewHTTPDNSRuntimeLogServiceClient(client) client.HTTPDNSAccessLogRPC = pb.NewHTTPDNSAccessLogServiceClient(client) client.HTTPDNSSandboxRPC = pb.NewHTTPDNSSandboxServiceClient(client) err := client.init() if err != nil { return nil, err } return client, nil } func (c *RPCClient) Context() context.Context { ctx := context.Background() payload := maps.Map{ "timestamp": time.Now().Unix(), "type": "httpdns", "userId": 0, } method, err := encrypt.NewMethodInstance(teaconst.EncryptMethod, c.apiConfig.Secret, c.apiConfig.NodeId) if err != nil { return context.Background() } encrypted, err := method.Encrypt(payload.AsJSON()) if err != nil { return context.Background() } token := base64.StdEncoding.EncodeToString(encrypted) return metadata.AppendToOutgoingContext(ctx, "nodeId", c.apiConfig.NodeId, "token", token) } func (c *RPCClient) UpdateConfig(config *configs.APIConfig) error { c.apiConfig = config c.locker.Lock() defer c.locker.Unlock() return c.init() } func (c *RPCClient) init() error { conns := []*grpc.ClientConn{} for _, endpoint := range c.apiConfig.RPCEndpoints { u, err := url.Parse(endpoint) if err != nil { return fmt.Errorf("parse endpoint failed: %w", err) } var conn *grpc.ClientConn callOptions := grpc.WithDefaultCallOptions( grpc.MaxCallRecvMsgSize(128<<20), grpc.MaxCallSendMsgSize(128<<20), grpc.UseCompressor(gzip.Name), ) keepaliveParams := grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, }) if u.Scheme == "http" { conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()), callOptions, keepaliveParams) } else if u.Scheme == "https" { conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ InsecureSkipVerify: true, })), callOptions, keepaliveParams) } else { return errors.New("invalid endpoint scheme '" + u.Scheme + "'") } if err != nil { return err } conns = append(conns, conn) } if len(conns) == 0 { return errors.New("no available rpc endpoints") } c.conns = conns return nil } func (c *RPCClient) pickConn() *grpc.ClientConn { c.locker.RLock() defer c.locker.RUnlock() countConns := len(c.conns) if countConns == 0 { return nil } if countConns == 1 { return c.conns[0] } for _, state := range []connectivity.State{ connectivity.Ready, connectivity.Idle, connectivity.Connecting, connectivity.TransientFailure, } { available := []*grpc.ClientConn{} for _, conn := range c.conns { if conn.GetState() == state { available = append(available, conn) } } if len(available) > 0 { return c.randConn(available) } } return c.randConn(c.conns) } func (c *RPCClient) randConn(conns []*grpc.ClientConn) *grpc.ClientConn { l := len(conns) if l == 0 { return nil } if l == 1 { return conns[0] } return conns[rands.Int(0, l-1)] } func (c *RPCClient) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error { conn := c.pickConn() if conn == nil { return errors.New("can not get available grpc connection") } atomic.AddInt64(&c.totalRequests, 1) start := time.Now() err := conn.Invoke(ctx, method, args, reply, opts...) costMs := time.Since(start).Milliseconds() atomic.AddInt64(&c.totalCostMs, costMs) if err != nil { atomic.AddInt64(&c.failedRequests, 1) } return err } func (c *RPCClient) GetAndResetMetrics() (total int64, failed int64, avgCostSeconds float64) { total = atomic.SwapInt64(&c.totalRequests, 0) failed = atomic.SwapInt64(&c.failedRequests, 0) costMs := atomic.SwapInt64(&c.totalCostMs, 0) if total > 0 { avgCostSeconds = float64(costMs) / float64(total) / 1000.0 } return } func (c *RPCClient) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { conn := c.pickConn() if conn == nil { return nil, errors.New("can not get available grpc connection") } return conn.NewStream(ctx, desc, method, opts...) }