Skip to content
This repository has been archived by the owner on Sep 30, 2024. It is now read-only.

feature/internal/grpc: retry: vendor go-grpc-middleware testing/testpb package #64198

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions internal/grpc/retry/testpb/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("//dev:go_defs.bzl", "go_test")

go_library(
name = "testpb",
srcs = [
"interceptor_suite.go",
"pingservice.go",
"test.manual_validator.pb.go",
"test.pb.go",
"test_grpc.pb.go",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/grpc/retry/testpb",
visibility = ["//:__subpackages__"],
deps = [
"@com_github_stretchr_testify//require",
"@com_github_stretchr_testify//suite",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//credentials",
"@org_golang_google_grpc//credentials/insecure",
"@org_golang_google_grpc//status",
"@org_golang_google_protobuf//reflect/protoreflect",
"@org_golang_google_protobuf//runtime/protoimpl",
],
)

go_test(
name = "testpb_test",
srcs = ["pingservice_test.go"],
embed = [":testpb"],
deps = [
"@com_github_stretchr_testify//require",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//credentials/insecure",
"@org_golang_google_grpc//status",
],
)
233 changes: 233 additions & 0 deletions internal/grpc/retry/testpb/interceptor_suite.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
// Copyright (c) The go-grpc-middleware Authors.
// Licensed under the Apache License 2.0.

package testpb

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"flag"
"math/big"
"net"
"sync"
"time"

"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)

var (
flagTls = flag.Bool("use_tls", true, "whether all gRPC middleware tests should use tls")

certPEM []byte
keyPEM []byte
)

// InterceptorTestSuite is a testify/Suite that starts a gRPC PingService server and a client.
type InterceptorTestSuite struct {
suite.Suite

TestService TestServiceServer
ServerOpts []grpc.ServerOption
ClientOpts []grpc.DialOption

serverAddr string
ServerListener net.Listener
Server *grpc.Server
clientConn *grpc.ClientConn
Client TestServiceClient

restartServerWithDelayedStart chan time.Duration
serverRunning chan bool

cancels []context.CancelFunc
}

func (s *InterceptorTestSuite) SetupSuite() {
s.restartServerWithDelayedStart = make(chan time.Duration)
s.serverRunning = make(chan bool)

s.serverAddr = "127.0.0.1:0"
var err error
certPEM, keyPEM, err = generateCertAndKey([]string{"localhost", "example.com"}) // CI:LOCALHOST_OK
require.NoError(s.T(), err, "unable to generate test certificate/key")

go func() {
for {
var err error
s.ServerListener, err = net.Listen("tcp", s.serverAddr)
s.serverAddr = s.ServerListener.Addr().String()
require.NoError(s.T(), err, "must be able to allocate a port for serverListener")
if *flagTls {
cert, err := tls.X509KeyPair(certPEM, keyPEM)
require.NoError(s.T(), err, "unable to load test TLS certificate")
creds := credentials.NewServerTLSFromCert(&cert)
s.ServerOpts = append(s.ServerOpts, grpc.Creds(creds))
}
// This is the point where we hook up the interceptor.
s.Server = grpc.NewServer(s.ServerOpts...)
// Create a service if the instantiator hasn't provided one.
if s.TestService == nil {
s.TestService = &TestPingService{}
}
RegisterTestServiceServer(s.Server, s.TestService)

w := sync.WaitGroup{}
w.Add(1)
go func() {
_ = s.Server.Serve(s.ServerListener)
w.Done()
}()
if s.Client == nil {
s.Client = s.NewClient(s.ClientOpts...)
}

s.serverRunning <- true

d := <-s.restartServerWithDelayedStart
s.Server.Stop()
time.Sleep(d)
w.Wait()
}
}()

select {
case <-s.serverRunning:
case <-time.After(2 * time.Second):
s.T().Fatal("server failed to start before deadline")
}
}

func (s *InterceptorTestSuite) RestartServer(delayedStart time.Duration) <-chan bool {
s.restartServerWithDelayedStart <- delayedStart
time.Sleep(10 * time.Millisecond)
return s.serverRunning
}

func (s *InterceptorTestSuite) NewClient(dialOpts ...grpc.DialOption) TestServiceClient {
//lint:ignore SA1019 This is a vendored package, so we shouldn't be modifying it.
newDialOpts := append(dialOpts, grpc.WithBlock())
var err error
if *flagTls {
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(certPEM) {
s.T().Fatal("failed to append certificate")
}
creds := credentials.NewTLS(&tls.Config{ServerName: "localhost", RootCAs: cp}) // CI:LOCALHOST_OK
newDialOpts = append(newDialOpts, grpc.WithTransportCredentials(creds))
} else {
newDialOpts = append(newDialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
//lint:ignore SA1019 This is a vendored package, so we shouldn't be modifying it.
s.clientConn, err = grpc.DialContext(ctx, s.ServerAddr(), newDialOpts...)
require.NoError(s.T(), err, "must not error on client Dial")
return NewTestServiceClient(s.clientConn)
}

func (s *InterceptorTestSuite) ServerAddr() string {
return s.serverAddr
}

type ctxTestNumber struct{}

var (
ctxTestNumberKey = &ctxTestNumber{}
zero = 0
)

func ExtractCtxTestNumber(ctx context.Context) *int {
if v, ok := ctx.Value(ctxTestNumberKey).(*int); ok {
return v
}
return &zero
}

// UnaryServerInterceptor returns a new unary server interceptors that adds query information logging.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
// newCtx := newContext(ctx, log, opts)
newCtx := ctx
resp, err := handler(newCtx, req)
return resp, err
}
}

func (s *InterceptorTestSuite) SimpleCtx() context.Context {
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
ctx = context.WithValue(ctx, ctxTestNumberKey, 1)
s.cancels = append(s.cancels, cancel)
return ctx
}

func (s *InterceptorTestSuite) DeadlineCtx(deadline time.Time) context.Context {
ctx, cancel := context.WithDeadline(context.TODO(), deadline)
s.cancels = append(s.cancels, cancel)
return ctx
}

func (s *InterceptorTestSuite) TearDownSuite() {
time.Sleep(10 * time.Millisecond)
if s.ServerListener != nil {
s.Server.GracefulStop()
s.T().Logf("stopped grpc.Server at: %v", s.ServerAddr())
_ = s.ServerListener.Close()
}
if s.clientConn != nil {
_ = s.clientConn.Close()
}
for _, c := range s.cancels {
c()
}
}

// generateCertAndKey copied from https://github.com/johanbrandhorst/certify/blob/master/issuers/vault/vault_suite_test.go#L255
// with minor modifications.
func generateCertAndKey(san []string) ([]byte, []byte, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
notBefore := time.Now()
notAfter := notBefore.Add(time.Hour)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, nil, err
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "example.com",
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: san,
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv)
if err != nil {
return nil, nil, err
}
certOut := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: derBytes,
})
keyOut := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
})

return certOut, keyOut, nil
}
82 changes: 82 additions & 0 deletions internal/grpc/retry/testpb/pingservice.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) The go-grpc-middleware Authors.
// Licensed under the Apache License 2.0.

/*
Package `grpc_testing` provides helper functions for testing validators in this package.
*/

package testpb

import (
"context"
"io"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const (
// ListResponseCount is the expected number of responses to PingList
ListResponseCount = 100
)

var TestServiceFullName = _TestService_serviceDesc.ServiceName

// Interface implementation assert.
var _ TestServiceServer = &TestPingService{}

type TestPingService struct {
UnimplementedTestServiceServer
}

func (s *TestPingService) PingEmpty(_ context.Context, _ *PingEmptyRequest) (*PingEmptyResponse, error) {
return &PingEmptyResponse{}, nil
}

func (s *TestPingService) Ping(ctx context.Context, ping *PingRequest) (*PingResponse, error) {
// Modify the ctx value to verify the logger sees the value updated from the initial value
n := ExtractCtxTestNumber(ctx)
if n != nil {
*n = 42
}
// Send user trailers and headers.
return &PingResponse{Value: ping.Value, Counter: 0}, nil
}

func (s *TestPingService) PingError(_ context.Context, ping *PingErrorRequest) (*PingErrorResponse, error) {
code := codes.Code(ping.ErrorCodeReturned)
return nil, status.Error(code, "Userspace error")
}

func (s *TestPingService) PingList(ping *PingListRequest, stream TestService_PingListServer) error {
if ping.ErrorCodeReturned != 0 {
return status.Error(codes.Code(ping.ErrorCodeReturned), "foobar")
}

// Send user trailers and headers.
for i := 0; i < ListResponseCount; i++ {
if err := stream.Send(&PingListResponse{Value: ping.Value, Counter: int32(i)}); err != nil {
return err
}
}
return nil
}

func (s *TestPingService) PingStream(stream TestService_PingStreamServer) error {
count := 0
for {
ping, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return err
}
if err := stream.Send(&PingStreamResponse{Value: ping.Value, Counter: int32(count)}); err != nil {
return err
}

count += 1
}
return nil
}
Loading
Loading