Skip to content

Commit

Permalink
transport: refactor to split ClientStream from ServerStream from comm…
Browse files Browse the repository at this point in the history
…on Stream functionality (#7802)
  • Loading branch information
dfawley authored Nov 4, 2024
1 parent 70e8931 commit 2a18bfc
Show file tree
Hide file tree
Showing 12 changed files with 410 additions and 340 deletions.
115 changes: 115 additions & 0 deletions internal/transport/client_stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package transport

import (
"sync/atomic"

"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

// ClientStream implements streaming functionality for a gRPC client.
type ClientStream struct {
*Stream // Embed for common stream functionality.

ct ClientTransport
done chan struct{} // closed at the end of stream to unblock writers.
doneFunc func() // invoked at the end of stream.

headerChan chan struct{} // closed to indicate the end of header metadata.
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
// headerValid indicates whether a valid header was received. Only
// meaningful after headerChan is closed (always call waitOnHeader() before
// reading its value).
headerValid bool
header metadata.MD // the received header metadata
noHeaders bool // set if the client never received headers (set only after the stream is done).

bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream

status *status.Status // the status error received from the server
}

// BytesReceived indicates whether any bytes have been received on this stream.
func (s *ClientStream) BytesReceived() bool {
return atomic.LoadUint32(&s.bytesReceived) == 1
}

// Unprocessed indicates whether the server did not process this stream --
// i.e. it sent a refused stream or GOAWAY including this stream ID.
func (s *ClientStream) Unprocessed() bool {
return atomic.LoadUint32(&s.unprocessed) == 1
}

func (s *ClientStream) waitOnHeader() {
select {
case <-s.ctx.Done():
// Close the stream to prevent headers/trailers from changing after
// this function returns.
s.ct.CloseStream(s, ContextErr(s.ctx.Err()))
// headerChan could possibly not be closed yet if closeStream raced
// with operateHeaders; wait until it is closed explicitly here.
<-s.headerChan
case <-s.headerChan:
}
}

// RecvCompress returns the compression algorithm applied to the inbound
// message. It is empty string if there is no compression applied.
func (s *ClientStream) RecvCompress() string {
s.waitOnHeader()
return s.recvCompress
}

// Done returns a channel which is closed when it receives the final status
// from the server.
func (s *ClientStream) Done() <-chan struct{} {
return s.done
}

// Header returns the header metadata of the stream. Acquires the key-value
// pairs of header metadata once it is available. It blocks until i) the
// metadata is ready or ii) there is no header metadata or iii) the stream is
// canceled/expired.
func (s *ClientStream) Header() (metadata.MD, error) {
s.waitOnHeader()

if !s.headerValid || s.noHeaders {
return nil, s.status.Err()
}

return s.header.Copy(), nil
}

// TrailersOnly blocks until a header or trailers-only frame is received and
// then returns true if the stream was trailers-only. If the stream ends
// before headers are received, returns true, nil.
func (s *ClientStream) TrailersOnly() bool {
s.waitOnHeader()
return s.noHeaders
}

// Status returns the status received from the server.
// Status can be read safely only after the stream has ended,
// that is, after Done() is closed.
func (s *ClientStream) Status() *status.Status {
return s.status
}
32 changes: 17 additions & 15 deletions internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (ht *serverHandlerTransport) do(fn func()) error {
}
}

func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error {
func (ht *serverHandlerTransport) WriteStatus(s *ServerStream, st *status.Status) error {
ht.writeStatusMu.Lock()
defer ht.writeStatusMu.Unlock()

Expand Down Expand Up @@ -289,14 +289,14 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro

// writePendingHeaders sets common and custom headers on the first
// write call (Write, WriteHeader, or WriteStatus)
func (ht *serverHandlerTransport) writePendingHeaders(s *Stream) {
func (ht *serverHandlerTransport) writePendingHeaders(s *ServerStream) {
ht.writeCommonHeaders(s)
ht.writeCustomHeaders(s)
}

// writeCommonHeaders sets common headers on the first write
// call (Write, WriteHeader, or WriteStatus).
func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
func (ht *serverHandlerTransport) writeCommonHeaders(s *ServerStream) {
h := ht.rw.Header()
h["Date"] = nil // suppress Date to make tests happy; TODO: restore
h.Set("Content-Type", ht.contentType)
Expand All @@ -317,7 +317,7 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {

// writeCustomHeaders sets custom headers set on the stream via SetHeader
// on the first write call (Write, WriteHeader, or WriteStatus)
func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) {
func (ht *serverHandlerTransport) writeCustomHeaders(s *ServerStream) {
h := ht.rw.Header()

s.hdrMu.Lock()
Expand All @@ -333,7 +333,7 @@ func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) {
s.hdrMu.Unlock()
}

func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data mem.BufferSlice, _ *Options) error {
func (ht *serverHandlerTransport) Write(s *ServerStream, hdr []byte, data mem.BufferSlice, _ *Options) error {
// Always take a reference because otherwise there is no guarantee the data will
// be available after this function returns. This is what callers to Write
// expect.
Expand All @@ -357,7 +357,7 @@ func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data mem.BufferSl
return nil
}

func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
func (ht *serverHandlerTransport) WriteHeader(s *ServerStream, md metadata.MD) error {
if err := s.SetHeader(md); err != nil {
return err
}
Expand Down Expand Up @@ -385,7 +385,7 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
return err
}

func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*Stream)) {
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*ServerStream)) {
// With this transport type there will be exactly 1 stream: this HTTP request.
var cancel context.CancelFunc
if ht.timeoutSet {
Expand All @@ -408,16 +408,18 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream

ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
req := ht.req
s := &Stream{
id: 0, // irrelevant
ctx: ctx,
requestRead: func(int) {},
s := &ServerStream{
Stream: &Stream{
id: 0, // irrelevant
ctx: ctx,
requestRead: func(int) {},
buf: newRecvBuffer(),
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
},
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
}
s.trReader = &transportReader{
Expand Down
24 changes: 12 additions & 12 deletions internal/transport/handler_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {

func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
st := newHandleStreamTest(t)
handleStream := func(s *Stream) {
handleStream := func(s *ServerStream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}
Expand Down Expand Up @@ -313,7 +313,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
st.ht.WriteStatus(s, status.New(codes.OK, ""))
}
st.ht.HandleStreams(
context.Background(), func(s *Stream) { go handleStream(s) },
context.Background(), func(s *ServerStream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,
Expand Down Expand Up @@ -342,11 +342,11 @@ func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
st := newHandleStreamTest(t)

handleStream := func(s *Stream) {
handleStream := func(s *ServerStream) {
st.ht.WriteStatus(s, status.New(statusCode, msg))
}
st.ht.HandleStreams(
context.Background(), func(s *Stream) { go handleStream(s) },
context.Background(), func(s *ServerStream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,
Expand Down Expand Up @@ -379,7 +379,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
if err != nil {
t.Fatal(err)
}
runStream := func(s *Stream) {
runStream := func(s *ServerStream) {
defer bodyw.Close()
select {
case <-s.ctx.Done():
Expand All @@ -395,7 +395,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
}
ht.HandleStreams(
context.Background(), func(s *Stream) { go runStream(s) },
context.Background(), func(s *ServerStream) { go runStream(s) },
)
wantHeader := http.Header{
"Date": nil,
Expand All @@ -412,7 +412,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
// TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that
// concurrent "WriteStatus"s do not panic writing to closed "writes" channel.
func (s) TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *ServerStream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}
Expand All @@ -433,7 +433,7 @@ func (s) TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
// TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write"
// following "WriteStatus" does not panic writing to closed "writes" channel.
func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *ServerStream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}
Expand All @@ -444,10 +444,10 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
})
}

func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) {
st := newHandleStreamTest(t)
st.ht.HandleStreams(
context.Background(), func(s *Stream) { go handleStream(st, s) },
context.Background(), func(s *ServerStream) { go handleStream(st, s) },
)
}

Expand Down Expand Up @@ -476,11 +476,11 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
}

hst := newHandleStreamTest(t)
handleStream := func(s *Stream) {
handleStream := func(s *ServerStream) {
hst.ht.WriteStatus(s, st)
}
hst.ht.HandleStreams(
context.Background(), func(s *Stream) { go handleStream(s) },
context.Background(), func(s *ServerStream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,
Expand Down
Loading

0 comments on commit 2a18bfc

Please sign in to comment.