From e55b2616a086b5ef126178efab8a3f4d6884f4d4 Mon Sep 17 00:00:00 2001 From: spongehah <2635879218@qq.com> Date: Wed, 11 Sep 2024 13:22:10 +0800 Subject: [PATCH] WIP(x/net/http/client): Implement BodyChunk --- x/net/http/_demo/chunked/chunked.go | 29 + x/net/http/_demo/get/get.go | 6 +- x/net/http/_demo/headers/headers.go | 6 +- .../_demo/maxConnsPerHost/maxConnsPerHost.go | 6 +- x/net/http/_demo/post/post.go | 5 + x/net/http/_demo/redirect/redirect.go | 6 +- x/net/http/_demo/reuseConn/reuseConn.go | 12 +- x/net/http/_demo/server/chunkedServer.go | 42 + x/net/http/_demo/upload/upload.go | 8 +- x/net/http/bodyChunk.go | 104 ++ x/net/http/client.go | 3 +- x/net/http/header.go | 70 +- x/net/http/request.go | 11 +- x/net/http/response.go | 228 +++- x/net/http/transfer.go | 57 +- x/net/http/transport.go | 1084 +++++++---------- 16 files changed, 893 insertions(+), 784 deletions(-) create mode 100644 x/net/http/_demo/chunked/chunked.go create mode 100644 x/net/http/_demo/server/chunkedServer.go create mode 100644 x/net/http/bodyChunk.go diff --git a/x/net/http/_demo/chunked/chunked.go b/x/net/http/_demo/chunked/chunked.go new file mode 100644 index 0000000..7b33c0c --- /dev/null +++ b/x/net/http/_demo/chunked/chunked.go @@ -0,0 +1,29 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgoexamples/x/net/http" +) + +func main() { + resp, err := http.Get("http://localhost:8080/chunked") + if err != nil { + fmt.Println(err) + return + } + defer resp.Body.Close() + fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) +} diff --git a/x/net/http/_demo/get/get.go b/x/net/http/_demo/get/get.go index 6e91bd4..392cc72 100644 --- a/x/net/http/_demo/get/get.go +++ b/x/net/http/_demo/get/get.go @@ -15,7 +15,11 @@ func main() { } defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/headers/headers.go b/x/net/http/_demo/headers/headers.go index 5538923..41cc15f 100644 --- a/x/net/http/_demo/headers/headers.go +++ b/x/net/http/_demo/headers/headers.go @@ -38,7 +38,11 @@ func main() { } defer resp.Body.Close() fmt.Println(resp.Status) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { println(err.Error()) diff --git a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go index 5662251..eff95fc 100644 --- a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go +++ b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go @@ -22,7 +22,11 @@ func main() { defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) fmt.Println(resp.Proto) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/post/post.go b/x/net/http/_demo/post/post.go index fd756b3..b700028 100644 --- a/x/net/http/_demo/post/post.go +++ b/x/net/http/_demo/post/post.go @@ -17,6 +17,11 @@ func main() { } defer resp.Body.Close() fmt.Println(resp.Status) + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/redirect/redirect.go b/x/net/http/_demo/redirect/redirect.go index f189255..3d40f3b 100644 --- a/x/net/http/_demo/redirect/redirect.go +++ b/x/net/http/_demo/redirect/redirect.go @@ -16,7 +16,11 @@ func main() { defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) fmt.Println(resp.Proto) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/reuseConn/reuseConn.go b/x/net/http/_demo/reuseConn/reuseConn.go index bb460ce..bccfe9d 100644 --- a/x/net/http/_demo/reuseConn/reuseConn.go +++ b/x/net/http/_demo/reuseConn/reuseConn.go @@ -15,7 +15,11 @@ func main() { return } fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) @@ -31,7 +35,11 @@ func main() { return } fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err = io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/server/chunkedServer.go b/x/net/http/_demo/server/chunkedServer.go new file mode 100644 index 0000000..b79ad60 --- /dev/null +++ b/x/net/http/_demo/server/chunkedServer.go @@ -0,0 +1,42 @@ +package main + +import ( + "fmt" + "net/http" +) + +func chunkedHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("Content-Type", "text/plain") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + sentence := "This is a chunked encoded response. It will be sent in multiple parts. Note the delay between each section." + + words := []string{} + start := 0 + for i, r := range sentence { + if r == '。' || r == ',' || i == len(sentence)-1 { + words = append(words, sentence[start:i+1]) + start = i + 1 + } + } + + for _, word := range words { + fmt.Fprintf(w, "%s", word) + flusher.Flush() + } +} + +func main() { + http.HandleFunc("/chunked", chunkedHandler) + fmt.Println("Starting server on :8080") + err := http.ListenAndServe(":8080", nil) + if err != nil { + fmt.Printf("Error starting server: %s\n", err) + } +} \ No newline at end of file diff --git a/x/net/http/_demo/upload/upload.go b/x/net/http/_demo/upload/upload.go index fe7256b..b5baffa 100644 --- a/x/net/http/_demo/upload/upload.go +++ b/x/net/http/_demo/upload/upload.go @@ -11,7 +11,7 @@ import ( func main() { url := "http://httpbin.org/post" //url := "http://localhost:8080" - filePath := "/Users/spongehah/go/src/llgo/x/net/http/_demo/upload/example.txt" // Replace with your file path + filePath := "/Users/spongehah/Documents/code/GOPATH/src/llgo/x/net/http/_demo/upload/example.txt" // Replace with your file path //filePath := "/Users/spongehah/Downloads/xiaoshuo.txt" // Replace with your file path file, err := os.Open(filePath) @@ -36,7 +36,11 @@ func main() { } defer resp.Body.Close() fmt.Println("Status:", resp.Status) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } respBody, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/bodyChunk.go b/x/net/http/bodyChunk.go new file mode 100644 index 0000000..c1d1072 --- /dev/null +++ b/x/net/http/bodyChunk.go @@ -0,0 +1,104 @@ +package http + +import ( + "errors" + "io" + "sync" + + "github.com/goplus/llgo/c/libuv" +) + +type onceError struct { + sync.Mutex + err error +} + +func (a *onceError) Store(err error) { + a.Lock() + defer a.Unlock() + if a.err != nil { + return + } + a.err = err +} + +func (a *onceError) Load() error { + a.Lock() + defer a.Unlock() + return a.err +} + +func newBodyChunk(asyncHandle *libuv.Async) *bodyChunk { + return &bodyChunk{ + readCh: make(chan []byte, 1), + done: make(chan struct{}), + asyncHandle: asyncHandle, + } +} + +type bodyChunk struct { + chunk []byte + readCh chan []byte + asyncHandle *libuv.Async + + once sync.Once + done chan struct{} + + rerr onceError +} + +var ( + errClosedBodyChunk = errors.New("bodyChunk: read/write on closed body") +) + +func (bc *bodyChunk) Read(p []byte) (n int, err error) { + for n < len(p) { + if len(bc.chunk) == 0 { + select { + case chunk, ok := <-bc.readCh: + if !ok { + if n > 0 { + return n, nil + } + return 0, bc.readCloseError() + } + bc.chunk = chunk + bc.asyncHandle.Send() + case <-bc.done: + if n > 0 { + return n, nil + } + return 0, io.EOF + } + } + + copied := copy(p[n:], bc.chunk) + n += copied + bc.chunk = bc.chunk[copied:] + } + + return n, nil +} + +func (bc *bodyChunk) Close() error { + return bc.closeRead(nil) +} + +func (bc *bodyChunk) readCloseError() error { + if rerr := bc.rerr.Load(); rerr != nil { + return rerr + } + return errClosedBodyChunk +} + +func (bc *bodyChunk) closeRead(err error) error { + if err == nil { + err = io.EOF + } + bc.rerr.Store(err) + bc.once.Do(func() { + close(bc.done) + }) + //close(bc.done) + return nil +} diff --git a/x/net/http/client.go b/x/net/http/client.go index 002397a..7e26395 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -307,7 +307,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d forkReq() } - // TODO(spongehah) timeout(send) + // TODO(spongehah) tmp timeout(send) //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) req.timeoutch = make(chan struct{}, 1) req.deadline = deadline @@ -490,7 +490,6 @@ func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { return knownRoundTripperImpl(altRT, req) } return true - // TODO(spongehah) http2 //case *http2Transport, http2noDialH2RoundTripper: // return true } diff --git a/x/net/http/header.go b/x/net/http/header.go index 0d1e2cc..7c95411 100644 --- a/x/net/http/header.go +++ b/x/net/http/header.go @@ -75,15 +75,6 @@ func (h Header) Del(key string) { textproto.MIMEHeader(h).Del(key) } -// CanonicalHeaderKey returns the canonical format of the -// header key s. The canonicalization converts the first -// letter and any letter following a hyphen to upper case; -// the rest are converted to lowercase. For example, the -// canonical key for "accept-encoding" is "Accept-Encoding". -// If s contains a space or invalid header field bytes, it is -// returned without modifications. -func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } - // Clone returns a copy of h or nil if h is nil. func (h Header) Clone() Header { if h == nil { @@ -111,28 +102,6 @@ func (h Header) Clone() Header { return h2 } -var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") - -type keyValues struct { - key string - values []string -} - -// A headerSorter implements sort.Interface by sorting a []keyValues -// by key. It's used as a pointer, so it can fit in a sort.Interface -// interface value without allocation. -type headerSorter struct { - kvs []keyValues -} - -func (s *headerSorter) Len() int { return len(s.kvs) } -func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } -func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } - -var headerSorterPool = sync.Pool{ - New: func() any { return new(headerSorter) }, -} - // sortedKeyValues returns h's keys sorted in the returned kvs // slice. The headerSorter used to sort is also returned, for possible // return to headerSorterCache. @@ -199,6 +168,37 @@ func (h Header) writeSubset(reqHeaders *hyper.Headers, exclude map[string]bool) return nil } +var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") + +type keyValues struct { + key string + values []string +} + +// A headerSorter implements sort.Interface by sorting a []keyValues +// by key. It's used as a pointer, so it can fit in a sort.Interface +// interface value without allocation. +type headerSorter struct { + kvs []keyValues +} + +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } + +var headerSorterPool = sync.Pool{ + New: func() any { return new(headerSorter) }, +} + +// CanonicalHeaderKey returns the canonical format of the +// header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +// If s contains a space or invalid header field bytes, it is +// returned without modifications. +func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } + // hasToken reports whether token appears with v, ASCII // case-insensitive, with space or comma boundaries. // token must be all lowercase. @@ -251,11 +251,3 @@ func appendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, va resp.Header.Add(nameStr, valueStr) return hyper.IterContinue } - -func (resp *Response) PrintHeaders() { - for key, values := range resp.Header { - for _, value := range values { - fmt.Printf("%s: %s\n", key, value) - } - } -} diff --git a/x/net/http/request.go b/x/net/http/request.go index c5146ed..e9279fc 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -294,7 +294,7 @@ func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hype //} // Prepare the hyper.Request - hyperReq, err := r.newHyperRequest(taskData.pc.isProxy, taskData.req.extra) + hyperReq, err := r.newHyperRequest(taskData.pc.isProxy, taskData.req.extra, taskData.req) if err != nil { return err } @@ -308,7 +308,7 @@ func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hype return err } -func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.Request, error) { +func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header, treq *transportRequest) (*hyper.Request, error) { // Find the target host. Prefer the Host: header, but if that // is not given, use the host from the request URL. // @@ -401,11 +401,6 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.R } // Process Body,ContentLength,Close,Trailer - //tw, err := newTransferWriter(r) - //if err != nil { - // return err - //} - //err = tw.writeHeader(w, trace) err = r.writeHeader(reqHeaders) if err != nil { return nil, err @@ -433,7 +428,7 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.R } // Write body and trailer - err = r.writeBody(hyperReq) + err = r.writeBody(hyperReq, treq) if err != nil { return nil, err } diff --git a/x/net/http/response.go b/x/net/http/response.go index 6ff5b3d..a3a96fc 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -1,10 +1,12 @@ package http import ( + "compress/gzip" + "errors" "fmt" "io" "strconv" - "unsafe" + "sync" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" @@ -32,7 +34,198 @@ func (r *Response) closeBody() { } } -func ReadResponse(r *io.PipeReader, req *Request, hyperResp *hyper.Response) (*Response, error) { +// bodyIsWritable reports whether the Body supports writing. The +// Transport returns Writable bodies for 101 Switching Protocols +// responses. +// The Transport uses this method to determine whether a persistent +// connection is done being managed from its perspective. Once we +// return a writable response body to a user, the net/http package is +// done managing that connection. +func (r *Response) bodyIsWritable() bool { + _, ok := r.Body.(io.Writer) + return ok +} + +// Cookies parses and returns the cookies set in the Set-Cookie headers. +func (r *Response) Cookies() []*Cookie { + return readSetCookies(r.Header) +} + +func (r *Response) checkRespBody(taskData *taskData) (needContinue bool) { + pc := taskData.pc + bodyWritable := r.bodyIsWritable() + hasBody := taskData.req.Method != "HEAD" && r.ContentLength != 0 + + if r.Close || taskData.req.Close || r.StatusCode <= 199 || bodyWritable { + // Don't do keep-alive on error if either party requested a close + // or we get an unexpected informational (1xx) response. + // StatusCode 100 is already handled above. + pc.alive = false + } + + if !hasBody || bodyWritable { + replaced := pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) + + // Put the idle conn back into the pool before we send the response + // so if they process it quickly and make another request, they'll + // get this same conn. But we use the unbuffered channel 'rc' + // to guarantee that persistConn.roundTrip got out of its select + // potentially waiting for this persistConn to close. + pc.alive = pc.alive && + replaced && pc.tryPutIdleConn() + + if bodyWritable { + pc.closeErr = errCallerOwnsConn + } + + select { + case taskData.resc <- responseAndError{res: r}: + case <-taskData.callerGone: + readLoopDefer(pc, true) + return true + } + // Now that they've read from the unbuffered channel, they're safely + // out of the select that also waits on this goroutine to die, so + // we're allowed to exit now if needed (if alive is false) + readLoopDefer(pc, false) + return true + } + return false +} + +func (r *Response) wrapRespBody(taskData *taskData) { + body := &bodyEOFSignal{ + body: r.Body, + earlyCloseFn: func() error { + return nil + }, + fn: func(err error) error { + isEOF := err == io.EOF + if !isEOF { + if cerr := taskData.pc.canceled(); cerr != nil { + return cerr + } + } + return err + }, + } + r.Body = body + // TODO(spongehah) gzip(wrapRespBody) + //if taskData.addedGzip && EqualFold(r.Header.Get("Content-Encoding"), "gzip") { + // println("gzip reader") + // r.Body = &gzipReader{body: body} + // r.Header.Del("Content-Encoding") + // r.Header.Del("Content-Length") + // r.ContentLength = -1 + // r.Uncompressed = true + //} +} + +// bodyEOFSignal is used by the HTTP/1 transport when reading response +// bodies to make sure we see the end of a response body before +// proceeding and reading on the connection again. +// +// It wraps a ReadCloser but runs fn (if non-nil) at most +// once, right before its final (error-producing) Read or Close call +// returns. fn should return the new error to return from Read or Close. +// +// If earlyCloseFn is non-nil and Close is called before io.EOF is +// seen, earlyCloseFn is called instead of fn, and its return value is +// the return value from Close. +type bodyEOFSignal struct { + body io.ReadCloser + mu sync.Mutex // guards following 4 fields + closed bool // whether Close has been called + rerr error // sticky Read error + fn func(error) error // err will be nil on Read io.EOF + earlyCloseFn func() error // optional alt Close func used if io.EOF not seen +} + +var errReadOnClosedResBody = errors.New("http: read on closed response body") + +func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { + es.mu.Lock() + closed, rerr := es.closed, es.rerr + es.mu.Unlock() + if closed { + return 0, errReadOnClosedResBody + } + if rerr != nil { + return 0, rerr + } + + n, err = es.body.Read(p) + if err != nil { + es.mu.Lock() + defer es.mu.Unlock() + if es.rerr == nil { + es.rerr = err + } + err = es.condfn(err) + } + return +} + +func (es *bodyEOFSignal) Close() error { + es.mu.Lock() + defer es.mu.Unlock() + if es.closed { + return nil + } + es.closed = true + if es.earlyCloseFn != nil && es.rerr != io.EOF { + return es.earlyCloseFn() + } + err := es.body.Close() + return es.condfn(err) +} + +// caller must hold es.mu. +func (es *bodyEOFSignal) condfn(err error) error { + if es.fn == nil { + return err + } + err = es.fn(err) + es.fn = nil + return err +} + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type gzipReader struct { + _ incomparable + body *bodyEOFSignal // underlying HTTP/1 response body framing + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // any error from gzip.NewReader; sticky +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zr == nil { + if gz.zerr == nil { + gz.zr, gz.zerr = gzip.NewReader(gz.body) + } + if gz.zerr != nil { + return 0, gz.zerr + } + } + + gz.body.mu.Lock() + if gz.body.closed { + err = errReadOnClosedResBody + } + gz.body.mu.Unlock() + + if err != nil { + return 0, err + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} + +func ReadResponse(r io.ReadCloser, req *Request, hyperResp *hyper.Response) (*Response, error) { resp := &Response{ Request: req, Header: make(Header), @@ -65,20 +258,6 @@ func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { headers.Foreach(appendToResponseHeader, c.Pointer(resp)) } -// appendToResponseBody BodyForeachCallback function: Process the response body -func appendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { - writer := (*io.PipeWriter)(userdata) - bufLen := chunk.Len() - bytes := unsafe.Slice(chunk.Bytes(), bufLen) - _, err := writer.Write(bytes) - if err != nil { - fmt.Println("Error writing to response body:", err) - writer.Close() - return hyper.IterBreak - } - return hyper.IterContinue -} - // RFC 7234, section 5.4: Should treat // // Pragma: no-cache @@ -94,26 +273,9 @@ func fixPragmaCacheControl(header Header) { } } -// Cookies parses and returns the cookies set in the Set-Cookie headers. -func (r *Response) Cookies() []*Cookie { - return readSetCookies(r.Header) -} - // isProtocolSwitchHeader reports whether the request or response header // is for a protocol switch. func isProtocolSwitchHeader(h Header) bool { return h.Get("Upgrade") != "" && HeaderValuesContainsToken(h["Connection"], "Upgrade") } - -// bodyIsWritable reports whether the Body supports writing. The -// Transport returns Writable bodies for 101 Switching Protocols -// responses. -// The Transport uses this method to determine whether a persistent -// connection is done being managed from its perspective. Once we -// return a writable response body to a user, the net/http package is -// done managing that connection. -func (r *Response) bodyIsWritable() bool { - _, ok := r.Body.(io.Writer) - return ok -} diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index 103200c..818fb3c 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -98,7 +98,7 @@ func (uste *unsupportedTEError) Error() string { } // msg is *Request or *Response. -func readTransfer(msg any, r *io.PipeReader) (err error) { +func readTransfer(msg any, r io.ReadCloser) (err error) { t := &transferReader{RequestMethod: "GET"} // Unify input @@ -173,19 +173,17 @@ func readTransfer(msg any, r *io.PipeReader) (err error) { if isResponse && noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { t.Body = NoBody } else { - // TODO(spongehah) ChunkReader(readTransfer) - //t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} - t.Body = &body{src: r, hdr: msg, r: r, closing: t.Close} + t.Body = &body{src: r, closer: r, hdr: msg, r: r, closing: t.Close} } case realLength == 0: t.Body = NoBody case realLength > 0: - t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} + t.Body = &body{src: io.LimitReader(r, realLength), closer: r, closing: t.Close} default: // realLength < 0, i.e. "Content-Length" not mentioned in header if t.Close { // Close semantics (i.e. HTTP/1.0) - t.Body = &body{src: r, closing: t.Close} + t.Body = &body{src: r, closer: r, closing: t.Close} } else { // Persistent connection (i.e. HTTP/1.1) t.Body = NoBody @@ -349,9 +347,9 @@ func fixTrailer(header Header, chunked bool) (Header, error) { // Close ensures that the body has been fully read // and then reads the trailer if necessary. type body struct { - src io.Reader - hdr any // non-nil (Response or Request) value means read trailer - //r *bufio.Reader // underlying wire-format reader for the trailer + src io.Reader + closer io.Closer + hdr any // non-nil (Response or Request) value means read trailer r io.Reader // underlying wire-format reader for the trailer closing bool // is the connection to be closed after reading body? doEarlyClose bool // whether Close should stop early @@ -476,6 +474,15 @@ func (b *body) Close() error { _, err = io.Copy(io.Discard, bodyLocked{b}) } b.closed = true + + // Close bodyChunk + if b.closer != nil { + closeErr := b.closer.Close() + if err == nil { + err = closeErr + } + } + return err } @@ -654,26 +661,26 @@ func unwrapNopCloser(r io.Reader) (underlyingReader io.Reader, isNopCloser bool) // files (*os.File types) are properly optimized. // // This function is only intended for use in writeBody. -func (req *Request) unwrapBody() io.Reader { - if r, ok := unwrapNopCloser(req.Body); ok { +func (r *Request) unwrapBody() io.Reader { + if r, ok := unwrapNopCloser(r.Body); ok { return r } - if r, ok := req.Body.(*readTrackingBody); ok { + if r, ok := r.Body.(*readTrackingBody); ok { r.didRead = true return r.ReadCloser } - return req.Body + return r.Body } -func (r *Request) writeBody(hyperReq *hyper.Request) error { +func (r *Request) writeBody(hyperReq *hyper.Request, treq *transportRequest) error { if r.Body != nil { var body = r.unwrapBody() hyperReqBody := hyper.NewBody() buf := make([]byte, defaultChunkSize) reqData := &bodyReq{ - body: body, - buf: buf, - closeBody: r.closeBody, + body: body, + buf: buf, + treq: treq, } hyperReqBody.SetUserdata(c.Pointer(reqData)) hyperReqBody.SetDataFunc(setPostData) @@ -683,9 +690,9 @@ func (r *Request) writeBody(hyperReq *hyper.Request) error { } type bodyReq struct { - body io.Reader - buf []byte - closeBody func() error + body io.Reader + buf []byte + treq *transportRequest } func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { @@ -694,10 +701,11 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In if err != nil { if err == io.EOF { *chunk = nil - req.closeBody() + req.treq.closeBody() return hyper.PollReady } fmt.Println("error reading request body: ", err) + req.treq.setError(requestBodyReadError{err}) return hyper.PollError } if n > 0 { @@ -706,10 +714,11 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In } if n == 0 { *chunk = nil - req.closeBody() + req.treq.closeBody() return hyper.PollReady } - req.closeBody() - fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + req.treq.closeBody() + err = fmt.Errorf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + req.treq.setError(requestBodyReadError{err}) return hyper.PollError } diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 44d721d..8075133 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -1,7 +1,6 @@ package http import ( - "compress/gzip" "container/list" "context" "errors" @@ -37,7 +36,12 @@ var DefaultTransport RoundTripper = &Transport{ // DefaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 -const debugSwitch = true + +// Debug switch provided for developers +const ( + debugSwitch = true + debugReadWriteLoop = true +) type Transport struct { idleMu sync.Mutex @@ -46,53 +50,24 @@ type Transport struct { idleConnWait map[connectMethodKey]wantConnQueue // waiting getConns idleLRU connLRU - altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme reqMu sync.Mutex reqCanceler map[cancelKey]func(error) - Proxy func(*Request) (*url.URL, error) + + altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme connsPerHostMu sync.Mutex connsPerHost map[connectMethodKey]int connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns - // DisableKeepAlives, if true, disables HTTP keep-alives and - // will only use the connection to the server for a single - // HTTP request. - // - // This is unrelated to the similarly named TCP keep-alives. - DisableKeepAlives bool - - // DisableCompression, if true, prevents the Transport from - // requesting compression with an "Accept-Encoding: gzip" - // request header when the Request contains no existing - // Accept-Encoding value. If the Transport requests gzip on - // its own and gets a gzipped response, it's transparently - // decoded in the Response.Body. However, if the user - // explicitly requested gzip it is not automatically - // uncompressed. - DisableCompression bool + Proxy func(*Request) (*url.URL, error) - // MaxIdleConns controls the maximum number of idle (keep-alive) - // connections across all hosts. Zero means no limit. - MaxIdleConns int + DisableKeepAlives bool + DisableCompression bool - // MaxIdleConnsPerHost, if non-zero, controls the maximum idle - // (keep-alive) connections to keep per-host. If zero, - // DefaultMaxIdleConnsPerHost is used. + MaxIdleConns int MaxIdleConnsPerHost int - - // MaxConnsPerHost optionally limits the total number of - // connections per host, including connections in the dialing, - // active, and idle states. On limit violation, dials will block. - // - // Zero means no limit. - MaxConnsPerHost int - - // IdleConnTimeout is the maximum amount of time an idle - // (keep-alive) connection will remain idle before closing - // itself. - // Zero means no limit. - IdleConnTimeout time.Duration + MaxConnsPerHost int + IdleConnTimeout time.Duration // libuv and hyper related loopInitOnce sync.Once @@ -516,14 +491,11 @@ func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { // useRegisteredProtocol reports whether an alternate protocol (as registered // with Transport.RegisterProtocol) should be respected for this request. func (t *Transport) useRegisteredProtocol(req *Request) bool { - if req.URL.Scheme == "https" && req.requiresHTTP1() { - // If this request requires HTTP/1, don't use the - // "https" alternate protocol, which is used by the - // HTTP/2 code to take over requests if there's an - // existing cached HTTP/2 connection. - return false - } - return true + // If this request requires HTTP/1, don't use the + // "https" alternate protocol, which is used by the + // HTTP/2 code to take over requests if there's an + // existing cached HTTP/2 connection. + return !(req.URL.Scheme == "https" && req.requiresHTTP1()) } // CancelRequest cancels an in-flight request by closing its connection. @@ -573,6 +545,8 @@ func (t *Transport) closeLocked(err error) { } } +// ---------------------------------------------------------- + func getMilliseconds(deadline time.Time) uint64 { microseconds := deadline.Sub(time.Now()).Microseconds() milliseconds := microseconds / 1e3 @@ -582,15 +556,13 @@ func getMilliseconds(deadline time.Time) uint64 { return uint64(milliseconds) } -// ---------------------------------------------------------- - func (t *Transport) RoundTrip(req *Request) (*Response, error) { if debugSwitch { - println("RoundTrip start") - defer println("RoundTrip end") + println("############### RoundTrip start") + defer println("############### RoundTrip end") } t.loopInitOnce.Do(func() { - println("init loop") + println("############### init loop") t.loop = libuv.LoopNew() t.async = &libuv.Async{} t.exec = hyper.NewExecutor() @@ -620,7 +592,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { req.timer.Start(onTimeout, getMilliseconds(req.deadline), 0) if debugSwitch { - println("timer start") + println("############### timer start") } didTimeout = func() bool { return req.timer.GetDueIn() == 0 } stopTimer = func() { @@ -628,7 +600,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { req.timer.Stop() (*libuv.Handle)(c.Pointer(req.timer)).Close(nil) if debugSwitch { - println("timer close") + println("############### timer close") } } } else { @@ -654,8 +626,8 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { func (t *Transport) doRoundTrip(req *Request) (*Response, error) { if debugSwitch { - println("doRoundTrip start") - defer println("doRoundTrip end") + println("############### doRoundTrip start") + defer println("############### doRoundTrip end") } //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) //ctx := req.Context() @@ -715,7 +687,6 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { } for { - // TODO(spongehah) timeout(t.doRoundTrip) //select { //case <-ctx.Done(): // req.closeBody() @@ -766,7 +737,6 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { } // Failed. Clean up and determine whether to retry. - // TODO(spongehah) ConnPool(t.doRoundTrip) if http2isNoCachedConnError(err) { if t.removeIdleConn(pconn) { t.decConnsPerHost(pconn.cacheKey) @@ -800,8 +770,8 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error) { if debugSwitch { - println("getConn start") - defer println("getConn end") + println("############### getConn start") + defer println("############### getConn end") } req := treq.Request //trace := treq.trace @@ -824,13 +794,11 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi } }() - // TODO(spongehah) ConnPool(t.getConn) // Queue for idle connection. if delivered := t.queueForIdleConn(w); delivered { pc := w.pc // Trace only for HTTP/1. // HTTP/2 calls trace.GotConn itself. - // TODO(spongehah) trace(t.getConn) //if pc.alt == nil && trace != nil && trace.GotConn != nil { // trace.GotConn(pc.gotIdleConnTrace(pc.idleAt)) //} @@ -853,28 +821,28 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()}) //} if w.err != nil { - // If the request has been canceled, that's probably - // what caused w.err; if so, prefer to return the - // cancellation error (see golang.org/issue/16049). - select { - // TODO(spongehah) timeout(t.getConn) - //case <-req.Cancel: - // return nil, errRequestCanceledConn - //case <-req.Context().Done(): - // return nil, req.Context().Err() - case <-req.timeoutch: - if debugSwitch { - println("getConn: timeoutch") - } - return nil, errors.New("timeout: req.Context().Err()") - case err := <-cancelc: - if err == errRequestCanceled { - err = errRequestCanceledConn - } - return nil, err - default: - // return below + return nil, w.err + } + // If the request has been canceled, that's probably + // what caused w.err; if so, prefer to return the + // cancellation error (see golang.org/issue/16049). + select { + //case <-req.Cancel: + // return nil, errRequestCanceledConn + //case <-req.Context().Done(): + // return nil, req.Context().Err() + case <-req.timeoutch: + if debugSwitch { + println("############### getConn: timeoutch") + } + return nil, errors.New("timeout: req.Context().Err()") + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn } + return nil, err + default: + // return below } return w.pc, w.err } @@ -883,8 +851,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // Once w receives permission to dial, it will do so in a separate goroutine. func (t *Transport) queueForDial(w *wantConn) { if debugSwitch { - println("queueForDial start") - defer println("queueForDial end") + println("############### queueForDial start") + defer println("############### queueForDial end") } w.beforeDial() @@ -919,13 +887,12 @@ func (t *Transport) queueForDial(w *wantConn) { // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. func (t *Transport) dialConnFor(w *wantConn) { if debugSwitch { - println("dialConnFor start") - defer println("dialConnFor end") + println("############### dialConnFor start") + defer println("############### dialConnFor end") } defer w.afterDial() pc, err := t.dialConn(w.timeoutch, w.cm) - // TODO(spongehah) ConnPool(t.dialConnFor) delivered := w.tryDeliver(pc, err) // If the connection was successfully established but was not passed to w, // or is a shareable HTTP/2 connection @@ -994,8 +961,8 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn *persistConn, err error) { if debugSwitch { - println("dialConn start") - defer println("dialConn end") + println("############### dialConn start") + defer println("############### dialConn end") } select { case <-timeoutch: @@ -1009,7 +976,9 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * closech: make(chan struct{}, 1), writeLoopDone: make(chan struct{}, 1), alive: true, + chunkAsync: &libuv.Async{}, } + t.loop.Async(pconn.chunkAsync, readyToRead) //trace := httptrace.ContextClientTrace(ctx) //wrapErr := func(err error) error { @@ -1102,6 +1071,21 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * // } //} + pconn.closeErr = errReadLoopExiting + pconn.tryPutIdleConn = func() bool { + if err := pconn.t.tryPutIdleConn(pconn); err != nil { + pconn.closeErr = err + //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // trace.PutIdleConn(err) + //} + return false + } + //if trace != nil && trace.PutIdleConn != nil { + // trace.PutIdleConn(nil) + //} + return true + } + select { case <-timeoutch: err = errors.New("[t.dialConn] request timeout") @@ -1114,8 +1098,8 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * func (t *Transport) dial(addr string) (*connData, error) { if debugSwitch { - println("dial start") - defer println("dial end") + println("############### dial start") + defer println("############### dial end") } host, port, err := net.SplitHostPort(addr) if err != nil { @@ -1150,12 +1134,11 @@ func (t *Transport) dial(addr string) (*connData, error) { func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { if debugSwitch { - println("roundTrip start") - defer println("roundTrip end") + println("############### roundTrip start") + defer println("############### roundTrip end") } testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { - // TODO(spongehah) ConnPool(pc.roundTrip) pc.t.putOrCloseIdleConn(pc) return nil, errRequestCanceled } @@ -1168,40 +1151,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err headerFn(req.extraHeaders()) } - // Ask for a compressed version if the caller didn't set their - // own value for Accept-Encoding. We only attempt to - // uncompress the gzip stream if we were the layer that - // requested it. - requestedGzip := false - // TODO(spongehah) gzip(pc.roundTrip) - //if !pc.t.DisableCompression && - // req.Header.Get("Accept-Encoding") == "" && - // req.Header.Get("Range") == "" && - // req.Method != "HEAD" { - // // Request gzip only, not deflate. Deflate is ambiguous and - // // not as universally supported anyway. - // // See: https://zlib.net/zlib_faq.html#faq39 - // // - // // Note that we don't request this for HEAD requests, - // // due to a bug in nginx: - // // https://trac.nginx.org/nginx/ticket/358 - // // https://golang.org/issue/5522 - // // - // // We don't request gzip if the request is for a range, since - // // auto-decoding a portion of a gzipped document will just fail - // // anyway. See https://golang.org/issue/8923 - // requestedGzip = true - // req.extraHeaders().Set("Accept-Encoding", "gzip") - //} - - // The 100-continue operation in Hyper is handled in the newHyperRequest function. - - // Keep-Alive - if pc.t.DisableKeepAlives && - !req.wantsClose() && - !isProtocolSwitchHeader(req.Header) { - req.extraHeaders().Set("Connection", "close") - } + // Set extra headers, such as Accept-Encoding, Connection(Keep-Alive). + requestedGzip := pc.setExtraHeaders(req) gone := make(chan struct{}, 1) defer close(gone) @@ -1229,9 +1180,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } if pc.client == nil && !pc.isReused() { - println("first") // Hookup the IO - hyperIo := newIoWithConnReadWrite(pc.conn) + hyperIo := newHyperIo(pc.conn) // We need an executor generally to poll futures // Prepare client options opts := hyper.NewClientConnOptions() @@ -1243,7 +1193,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // Send the request to readWriteLoop(). pc.t.exec.Push(handshakeTask) } else { - println("second") taskData.taskId = read err = req.write(pc.client, taskData, pc.t.exec) if err != nil { @@ -1264,12 +1213,12 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err for { testHookWaitResLoop() if debugSwitch { - println("roundTrip for") + println("############### roundTrip for") } select { case err := <-writeErrCh: if debugSwitch { - println("roundTrip: writeErrch") + println("############### roundTrip: writeErrch") } if err != nil { pc.close(fmt.Errorf("write error: %w", err)) @@ -1278,17 +1227,9 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } return nil, pc.mapRoundTripError(req, startBytesWritten, err) } - //if d := pc.t.ResponseHeaderTimeout; d > 0 { - // if debugRoundTrip { - // //req.logf("starting timer for %v", d) - // } - // timer := time.NewTimer(d) - // defer timer.Stop() // prevent leaks - // respHeaderTimer = timer.C - //} case <-pcClosed: if debugSwitch { - println("roundTrip: pcClosed") + println("############### roundTrip: pcClosed") } pcClosed = nil if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { @@ -1297,7 +1238,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err //case <-respHeaderTimer: case re := <-resc: if debugSwitch { - println("roundTrip: resc") + println("############### roundTrip: resc") } if (re.res == nil) == (re.err == nil) { return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) @@ -1306,7 +1247,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) } return re.res, nil - // TODO(spongehah) timeout(pc.roundTrip) //case <-cancelChan: // canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) // cancelChan = nil @@ -1316,7 +1256,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // ctxDoneChan = nil case <-timeoutch: if debugSwitch { - println("roundTrip: timeoutch") + println("############### roundTrip: timeoutch") } canceled = pc.t.cancelRequest(req.cancelKey, errors.New("timeout: req.Context().Err()")) timeoutch = nil @@ -1330,361 +1270,232 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err func readWriteLoop(checker *libuv.Check) { t := (*Transport)((*libuv.Handle)(c.Pointer(checker)).GetData()) - // Read this once, before loop starts. (to avoid races in tests) - //testHookMu.Lock() - //testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead - //testHookMu.Unlock() - - const debugReadWriteLoop = true // Debug switch provided for developers - - // The polling state machine! - // Poll all ready tasks and act on them... - for { - task := t.exec.Poll() + // The polling state machine! Poll all ready tasks and act on them... + task := t.exec.Poll() + for task != nil { if debugSwitch { - println("polling") - } - if task == nil { - return - } - taskData := (*taskData)(task.Userdata()) - var taskId taskId - if taskData != nil { - taskId = taskData.taskId - } else { - taskId = notSet + println("############### polling") } + t.handleTask(task) + task = t.exec.Poll() + } +} + +func (t *Transport) handleTask(task *hyper.Task) { + taskData := (*taskData)(task.Userdata()) + if taskData == nil { + // A background task for hyper_client completed... + task.Free() + return + } + var err error + pc := taskData.pc + // If original taskId is set, we need to check it + err = checkTaskType(task, taskData) + if err != nil { + readLoopDefer(pc, true) + return + } + switch taskData.taskId { + case handshake: if debugReadWriteLoop { - println("taskId: ", taskId) + println("############### write") } - switch taskId { - case handshake: - if debugReadWriteLoop { - println("write") - } - - err := checkTaskType(task, handshake) - if err != nil { - taskData.writeErrCh <- err - task.Free() - continue - } - - pc := taskData.pc - select { - case <-pc.closech: - task.Free() - continue - default: - } - pc.client = (*hyper.ClientConn)(task.Value()) + // Check if the connection is closed + select { + case <-pc.closech: task.Free() + return + default: + } - // TODO(spongehah) Proxy(writeLoop) - taskData.taskId = read - err = taskData.req.Request.write(pc.client, taskData, t.exec) + pc.client = (*hyper.ClientConn)(task.Value()) + task.Free() - if err != nil { - //pc.writeErrCh <- err // to the body reader, which might recycle us - taskData.writeErrCh <- err // to the roundTrip function - pc.close(err) - continue - } - - if debugReadWriteLoop { - println("write end") - } - case read: - if debugReadWriteLoop { - println("read") - } - - pc := taskData.pc - - err := checkTaskType(task, read) - if bre, ok := err.(requestBodyReadError); ok { - err = bre.error - // Errors reading from the user's - // Request.Body are high priority. - // Set it here before sending on the - // channels below or calling - // pc.close() which tears down - // connections and causes other - // errors. - taskData.req.setError(err) - } - if err != nil { - //pc.writeErrCh <- err // to the body reader, which might recycle us - taskData.writeErrCh <- err // to the roundTrip function - pc.close(err) - continue - } + // TODO(spongehah) Proxy(writeLoop) + taskData.taskId = read + err = taskData.req.Request.write(pc.client, taskData, t.exec) - if pc.closeErr == nil { - pc.closeErr = errReadLoopExiting - } - // TODO(spongehah) ConnPool(readWriteLoop) - if pc.tryPutIdleConn == nil { - pc.tryPutIdleConn = func() bool { - if err := pc.t.tryPutIdleConn(pc); err != nil { - pc.closeErr = err - // TODO(spongehah) trace(readWriteLoop) - //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { - // trace.PutIdleConn(err) - //} - return false - } - //if trace != nil && trace.PutIdleConn != nil { - // trace.PutIdleConn(nil) - //} - return true - } - } + if err != nil { + //pc.writeErrCh <- err // to the body reader, which might recycle us + taskData.writeErrCh <- err // to the roundTrip function + pc.close(err) + return + } - // Take the results - hyperResp := (*hyper.Response)(task.Value()) - task.Free() + if debugReadWriteLoop { + println("############### write end") + } + case read: + if debugReadWriteLoop { + println("############### read") + } - pc.mu.Lock() - if pc.numExpectedResponses == 0 { - pc.readLoopPeekFailLocked(hyperResp, err) - pc.mu.Unlock() + // Take the results + hyperResp := (*hyper.Response)(task.Value()) + task.Free() - // defer - readLoopDefer(pc, t) - continue - } + //pc.mu.Lock() + if pc.numExpectedResponses == 0 { + pc.readLoopPeekFailLocked(hyperResp, err) pc.mu.Unlock() + readLoopDefer(pc, true) + return + } + //pc.mu.Unlock() - //trace := httptrace.ContextClientTrace(rc.req.Context()) - - var resp *Response - var respBody *hyper.Body - if err == nil { - var pr *io.PipeReader - pr, taskData.bodyWriter = io.Pipe() - resp, err = ReadResponse(pr, taskData.req.Request, hyperResp) - respBody = hyperResp.Body() - } else { - err = transportReadFromServerError{err} - pc.closeErr = err - } + var resp *Response + if err == nil { + pc.chunkAsync.SetData(c.Pointer(taskData)) + bc := newBodyChunk(pc.chunkAsync) + pc.bodyChunk = bc + resp, err = ReadResponse(bc, taskData.req.Request, hyperResp) + taskData.hyperBody = hyperResp.Body() + } else { + err = transportReadFromServerError{err} + pc.closeErr = err + } - // No longer need the response - hyperResp.Free() + // No longer need the response + hyperResp.Free() - if err != nil { - select { - case taskData.resc <- responseAndError{err: err}: - case <-taskData.callerGone: - // defer - readLoopDefer(pc, t) - continue - } - // defer - readLoopDefer(pc, t) - continue + if err != nil { + select { + case taskData.resc <- responseAndError{err: err}: + case <-taskData.callerGone: + readLoopDefer(pc, true) + return } + readLoopDefer(pc, true) + return + } - pc.mu.Lock() - pc.numExpectedResponses-- - pc.mu.Unlock() - - bodyWritable := resp.bodyIsWritable() - hasBody := taskData.req.Method != "HEAD" && resp.ContentLength != 0 + dataTask := taskData.hyperBody.Data() + taskData.taskId = readBodyChunk + dataTask.SetUserdata(c.Pointer(taskData)) + t.exec.Push(dataTask) - if resp.Close || taskData.req.Close || resp.StatusCode <= 199 || bodyWritable { - // Don't do keep-alive on error if either party requested a close - // or we get an unexpected informational (1xx) response. - // StatusCode 100 is already handled above. - pc.alive = false - } + if !taskData.req.deadline.IsZero() { + (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData + } - if !hasBody || bodyWritable { - replaced := pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) - - // TODO(spongehah) ConnPool(readWriteLoop) - // Put the idle conn back into the pool before we send the response - // so if they process it quickly and make another request, they'll - // get this same conn. But we use the unbuffered channel 'rc' - // to guarantee that persistConn.roundTrip got out of its select - // potentially waiting for this persistConn to close. - pc.alive = pc.alive && - replaced && pc.tryPutIdleConn() - //pc.alive = pc.alive && - // !pc.sawEOF && - // pc.wroteRequest() && - // replaced && pc.tryPutIdleConn() - - if bodyWritable { - pc.closeErr = errCallerOwnsConn - } + //pc.mu.Lock() + pc.numExpectedResponses-- + //pc.mu.Unlock() - select { - case taskData.resc <- responseAndError{res: resp}: - case <-taskData.callerGone: - // defer - readLoopDefer(pc, t) - continue - } - // Now that they've read from the unbuffered channel, they're safely - // out of the select that also waits on this goroutine to die, so - // we're allowed to exit now if needed (if alive is false) - //testHookReadLoopBeforeNextRead() - if pc.alive == false { - // defer - readLoopDefer(pc, t) - } - continue - } + needContinue := resp.checkRespBody(taskData) + if needContinue { + return + } - body := &bodyEOFSignal{ - body: resp.Body, - earlyCloseFn: func() error { - taskData.bodyWriter.Close() - return nil - }, - fn: func(err error) error { - isEOF := err == io.EOF - if !isEOF { - if cerr := pc.canceled(); cerr != nil { - return cerr - } - } - return err - }, - } - resp.Body = body - - // TODO(spongehah) gzip(pc.readWriteLoop) - //if taskData.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { - // println("gzip reader") - // resp.Body = &gzipReader{body: body} - // resp.Header.Del("Content-Encoding") - // resp.Header.Del("Content-Length") - // resp.ContentLength = -1 - // resp.Uncompressed = true - //} + resp.wrapRespBody(taskData) - bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(taskData.bodyWriter)) - taskData.taskId = readDone - bodyForeachTask.SetUserdata(c.Pointer(taskData)) - t.exec.Push(bodyForeachTask) - if taskData.req.timer != nil { - (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData - } + // FIXME: Waiting for the channel bug to be fixed + //select { + //case taskData.resc <- responseAndError{res: resp}: + //case <-taskData.callerGone: + // // defer + // readLoopDefer(pc, true) + // return + //} + select { + case <-taskData.callerGone: + readLoopDefer(pc, true) + return + default: + } + taskData.resc <- responseAndError{res: resp} - // TODO(spongehah) select blocking(readWriteLoop) - //select { - //case taskData.resc <- responseAndError{res: resp}: - //case <-taskData.callerGone: - // // defer - // readLoopDefer(pc, t) - // continue - //} - select { - case <-taskData.callerGone: - // defer - readLoopDefer(pc, t) - continue - default: - } - taskData.resc <- responseAndError{res: resp} + if debugReadWriteLoop { + println("############### read end") + } + case readBodyChunk: + if debugReadWriteLoop { + println("############### readBodyChunk") + } + taskType := task.Type() + if taskType == hyper.TaskBuf { + chunk := (*hyper.Buf)(task.Value()) + chunkLen := chunk.Len() + bytes := unsafe.Slice(chunk.Bytes(), chunkLen) + // Free chunk and task + chunk.Free() + task.Free() + // Write to the channel + pc.bodyChunk.readCh <- bytes if debugReadWriteLoop { - println("read end") - } - case readDone: - // A background task of reading the response body is completed - if debugReadWriteLoop { - println("readDone") - } - if taskData.bodyWriter != nil { - taskData.bodyWriter.Close() + println("############### readBodyChunk end [buf]") } - checkTaskType(task, readDone) - - bodyEOF := task.Type() == hyper.TaskEmpty - // free the task - task.Free() - - pc := taskData.pc - - replaced := t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool - // TODO(spongehah) ConnPool(readWriteLoop) - pc.alive = pc.alive && - bodyEOF && - replaced && pc.tryPutIdleConn() - //pc.alive = pc.alive && - // bodyEOF && - // !pc.sawEOF && - // pc.wroteRequest() && - // replaced && tryPutIdleConn(trace) - - // TODO(spongehah) timeout(t.readWriteLoop) - //case <-rw.rc.req.Cancel: - // pc.alive = false - // pc.t.CancelRequest(rw.rc.req) - //case <-rw.rc.req.Context().Done(): - // pc.alive = false - // pc.t.cancelRequest(rw.rc.cancelKey, rw.rc.req.Context().Err()) - //case <-pc.closech: - // pc.alive = false - //} + return + } - //select { - //case <-taskData.req.timeoutch: - // continue - //case <-pc.closech: - // pc.alive = false - //default: - //} + // taskType == taskEmpty (check in checkTaskType) + task.Free() + taskData.hyperBody.Free() + taskData.hyperBody = nil + pc.bodyChunk.Close() + replaced := t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool + pc.alive = pc.alive && + replaced && pc.tryPutIdleConn() - if pc.alive == false { - // defer - readLoopDefer(pc, t) - } + readLoopDefer(pc, false) - //testHookReadLoopBeforeNextRead() - if debugReadWriteLoop { - println("readDone end") - } - case notSet: - // A background task for hyper_client completed... - task.Free() + if debugReadWriteLoop { + println("############### readBodyChunk end [empty]") } } } -func readLoopDefer(pc *persistConn, t *Transport) { +func readyToRead(aysnc *libuv.Async) { + println("############### AsyncCb: readyToRead") + taskData := (*taskData)(aysnc.GetData()) + dataTask := taskData.hyperBody.Data() + dataTask.SetUserdata(c.Pointer(taskData)) + taskData.pc.t.exec.Push(dataTask) +} + +// readLoopDefer Replace the defer function of readLoop in stdlib +func readLoopDefer(pc *persistConn, force bool) { + if pc.alive == true && !force { + return + } pc.close(pc.closeErr) - // TODO(spongehah) ConnPool(readLoopDefer) - t.removeIdleConn(pc) + pc.t.removeIdleConn(pc) } // ---------------------------------------------------------- +type connData struct { + TcpHandle libuv.Tcp + ConnectReq libuv.Connect + ReadBuf libuv.Buf + ReadBufFilled uintptr + nwrite int64 // bytes written(Replaced from persistConn's nwrite) + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker +} + type taskData struct { taskId taskId - bodyWriter *io.PipeWriter req *transportRequest pc *persistConn addedGzip bool writeErrCh chan error callerGone chan struct{} resc chan responseAndError + hyperBody *hyper.Body } -type connData struct { - TcpHandle libuv.Tcp - ConnectReq libuv.Connect - ReadBuf libuv.Buf - ReadBufFilled uintptr - nwrite int64 // bytes written(Replaced from persistConn's nwrite) - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker -} +// taskId The unique identifier of the next task polled from the executor +type taskId c.Int + +const ( + handshake taskId = iota + 1 + read + readBodyChunk +) func (conn *connData) Close() error { if conn == nil { @@ -1709,8 +1520,8 @@ func (conn *connData) Close() error { // onConnect is the libuv callback for a successful connection func onConnect(req *libuv.Connect, status c.Int) { if debugSwitch { - println("connect start") - defer println("connect end") + println("############### connect start") + defer println("############### connect end") } conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) @@ -1723,8 +1534,6 @@ func onConnect(req *libuv.Connect, status c.Int) { // allocBuffer allocates a buffer for reading from a socket func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { - //conn := (*ConnData)(handle.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(handle)).data conn := (*connData)(handle.GetData()) if conn.ReadBuf.Base == nil { conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) @@ -1738,31 +1547,21 @@ func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { // onRead is the libuv callback for reading from a socket // This callback function is called when data is available to be read func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { - // Get the connection data associated with the stream conn := (*connData)((*libuv.Handle)(c.Pointer(stream)).GetData()) - - // If data was read (nread > 0) if nread > 0 { - // Update the amount of filled buffer conn.ReadBufFilled += uintptr(nread) } - // If there's a pending read waker if conn.ReadWaker != nil { // Wake up the pending read operation of Hyper conn.ReadWaker.Wake() - // Clear the waker reference conn.ReadWaker = nil } } // readCallBack read callback function for Hyper library func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - // Get the user data (connection data) conn := (*connData)(userdata) - - // If there's data in the buffer if conn.ReadBufFilled > 0 { - // Calculate how much data to copy (minimum of filled amount and requested amount) var toCopy uintptr if bufLen < conn.ReadBufFilled { toCopy = bufLen @@ -1775,71 +1574,52 @@ func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uin c.Memmove(c.Pointer(conn.ReadBuf.Base), c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+toCopy), conn.ReadBufFilled-toCopy) // Update the amount of filled buffer conn.ReadBufFilled -= toCopy - // Return the number of bytes copied return toCopy } - // If no data in buffer, set up a waker to wait for more data - // Free the old waker if it exists if conn.ReadWaker != nil { conn.ReadWaker.Free() } - // Create a new waker conn.ReadWaker = ctx.Waker() - // Return HYPER_IO_PENDING to indicate operation is pending, waiting for more data return hyper.IoPending } // onWrite is the libuv callback for writing to a socket // Callback function called after a write operation completes func onWrite(req *libuv.Write, status c.Int) { - // Get the connection data associated with the write request conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) - - // If there's a pending write waker if conn.WriteWaker != nil { // Wake up the pending write operation conn.WriteWaker.Wake() - // Clear the waker reference conn.WriteWaker = nil } } // writeCallBack write callback function for Hyper library func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - // Get the user data (connection data) conn := (*connData)(userdata) - // Create a libuv buffer initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) req := &libuv.Write{} - // Associate the connection data with the write request (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) - // Perform the asynchronous write operation ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, onWrite) - // If the write operation was successfully initiated if ret >= 0 { conn.nwrite += int64(bufLen) - // Return the number of bytes to be written return bufLen } - // If the write operation can't complete immediately, set up a waker to wait for completion if conn.WriteWaker != nil { - // Free the old waker if it exists conn.WriteWaker.Free() } - // Create a new waker conn.WriteWaker = ctx.Waker() - // Return HYPER_IO_PENDING to indicate operation is pending, waiting for write to complete return hyper.IoPending } // onTimeout is the libuv callback for a timeout func onTimeout(timer *libuv.Timer) { if debugSwitch { - println("onTimeout start") - defer println("onTimeout end") + println("############### onTimeout start") + defer println("############### onTimeout end") } data := (*timeoutData)((*libuv.Handle)(c.Pointer(timer)).GetData()) close(data.timeoutch) @@ -1850,13 +1630,12 @@ func onTimeout(timer *libuv.Timer) { pc := taskData.pc pc.alive = false pc.t.cancelRequest(taskData.req.cancelKey, errors.New("timeout: req.Context().Err()")) - // defer - readLoopDefer(pc, pc.t) + readLoopDefer(pc, true) } } -// newIoWithConnReadWrite creates a new IO with read and write callbacks -func newIoWithConnReadWrite(connData *connData) *hyper.Io { +// newHyperIo creates a new IO with read and write callbacks +func newHyperIo(connData *connData) *hyper.Io { hyperIo := hyper.NewIo() hyperIo.SetUserdata(c.Pointer(connData)) hyperIo.SetRead(readCallBack) @@ -1864,79 +1643,98 @@ func newIoWithConnReadWrite(connData *connData) *hyper.Io { return hyperIo } -// taskId The unique identifier of the next task polled from the executor -type taskId c.Int - -const ( - notSet taskId = iota - handshake - read - readDone -) - // checkTaskType checks the task type -func checkTaskType(task *hyper.Task, curTaskId taskId) error { - switch curTaskId { - case handshake: - if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::handshake]handshake task error!\n") - return fail((*hyper.Error)(task.Value())) - } - if task.Type() != hyper.TaskClientConn { - return fmt.Errorf("[readWriteLoop::handshake]unexpected task type\n") - } - return nil - case read: - if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::read]write task error!\n") - return fail((*hyper.Error)(task.Value())) - } - if task.Type() != hyper.TaskResponse { - c.Printf(c.Str("[readWriteLoop::read]unexpected task type\n")) - return errors.New("[readWriteLoop::read]unexpected task type\n") +func checkTaskType(task *hyper.Task, taskData *taskData) (err error) { + curTaskId := taskData.taskId + taskType := task.Type() + if taskType == hyper.TaskError { + err = fail((*hyper.Error)(task.Value()), curTaskId) + } + if err == nil { + switch curTaskId { + case handshake: + if taskType != hyper.TaskClientConn { + err = errors.New("Unexpected hyper task type: expected to be TaskClientConn, actual is " + strTaskType(taskType)) + } + case read: + if taskType != hyper.TaskResponse { + err = errors.New("Unexpected hyper task type: expected to be TaskResponse, actual is " + strTaskType(taskType)) + } + case readBodyChunk: + if taskType != hyper.TaskBuf && taskType != hyper.TaskEmpty { + err = errors.New("Unexpected hyper task type: expected to be TaskBuf / TaskEmpty, actual is " + strTaskType(taskType)) + } } - return nil - case readDone: - if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::readDone]read response body error!\n") - return fail((*hyper.Error)(task.Value())) + } + if err != nil { + task.Free() + if curTaskId == handshake || curTaskId == read { + taskData.writeErrCh <- err + taskData.pc.close(err) } - return nil - case notSet: + taskData.pc.alive = false } - return errors.New("[readWriteLoop]unexpected task type\n") + return } // fail prints the error details and panics -func fail(err *hyper.Error) error { +func fail(err *hyper.Error, taskId taskId) error { if err != nil { - c.Printf(c.Str("[readWriteLoop]error code: %d\n"), err.Code()) // grab the error details var errBuf [256]c.Char errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) - - c.Printf(c.Str("[readWriteLoop]details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) + errDetails := unsafe.SliceData(errBuf[:errLen]) + details := c.GoString(errDetails) // clean up the error err.Free() - return fmt.Errorf("[readWriteLoop]hyper request error, error code: %d\n", int(err.Code())) + return fmt.Errorf("hyper request error, taskId: %s, details: %s\n", strTaskId(taskId), details) } return nil } +func strTaskType(taskType hyper.TaskReturnType) string { + switch taskType { + case hyper.TaskClientConn: + return "TaskClientConn" + case hyper.TaskResponse: + return "TaskResponse" + case hyper.TaskBuf: + return "TaskBuf" + case hyper.TaskEmpty: + return "TaskEmpty" + case hyper.TaskError: + return "TaskError" + default: + return "Unknown" + } +} + +func strTaskId(taskId taskId) string { + switch taskId { + case handshake: + return "handshake" + case read: + return "read" + case readBodyChunk: + return "readBodyChunk" + default: + return "notSet" + } +} + // ---------------------------------------------------------- // error values for debugging and testing, not seen by users. var ( - errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") - errConnBroken = errors.New("http: putIdleConn: connection is in bad state") - errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") - errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") - errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") - errCloseIdleConns = errors.New("http: CloseIdleConnections called") - errReadLoopExiting = errors.New("http: Transport.readWriteLoop.read exiting") - errReadWriteLoopExiting = errors.New("http: Transport.readWriteLoop exiting") - errIdleConnTimeout = errors.New("http: idle connection timeout") + errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") + errConnBroken = errors.New("http: putIdleConn: connection is in bad state") + errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") + errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") + errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") + errCloseIdleConns = errors.New("http: CloseIdleConnections called") + errReadLoopExiting = errors.New("http: Transport.readWriteLoop.read exiting") + errIdleConnTimeout = errors.New("http: idle connection timeout") // errServerClosedIdle is not seen by users for idempotent requests, but may be // seen by a user if the server shuts down an idle connection and sends its FIN @@ -1971,14 +1769,6 @@ func (e *httpError) Error() string { return e.err } func (e *httpError) Timeout() bool { return e.timeout } func (e *httpError) Temporary() bool { return true } -// fakeLocker is a sync.Locker which does nothing. It's used to guard -// test-only fields when not under test, to avoid runtime atomic -// overhead. -type fakeLocker struct{} - -func (fakeLocker) Lock() {} -func (fakeLocker) Unlock() {} - // nothingWrittenError wraps a write errors which ended up writing zero bytes. type nothingWrittenError struct { error @@ -2014,9 +1804,6 @@ var ( testHookRoundTripRetried = nop testHookPrePendingDial = nop testHookPostPendingDial = nop - - testHookMu sync.Locker = fakeLocker{} // guards following - testHookReadLoopBeforeNextRead = nop ) var portMap = map[string]string{ @@ -2076,10 +1863,34 @@ type persistConn struct { mutateHeaderFunc func(Header) // other - alive bool // Replace the alive in readLoop - closeErr error // Replace the closeErr in readLoop - tryPutIdleConn func() bool // Replace the tryPutIdleConn in readLoop - client *hyper.ClientConn + alive bool // Replace the alive in readLoop + closeErr error // Replace the closeErr in readLoop + tryPutIdleConn func() bool // Replace the tryPutIdleConn in readLoop + client *hyper.ClientConn // http long connection client handle + bodyChunk *bodyChunk // Implement non-blocking consumption of each responseBody chunk + chunkAsync *libuv.Async // Notifying that the received chunk has been read +} + +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle in +// a "keep-alive" state. It does not interrupt any connections currently +// in use. +func (t *Transport) CloseIdleConnections() { + //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) + t.idleMu.Lock() + m := t.idleConn + t.idleConn = nil + t.closeIdle = true // close newly idle connections + t.idleLRU = connLRU{} + t.idleMu.Unlock() + for _, conns := range m { + for _, pconn := range conns { + pconn.close(errCloseIdleConns) + } + } + //if t2 := t.h2transport; t2 != nil { + // t2.CloseIdleConnections() + //} } func (pc *persistConn) cancelRequest(err error) { @@ -2110,7 +1921,7 @@ func (pc *persistConn) markReused() { func (pc *persistConn) closeLocked(err error) { if debugSwitch { - println("pc closed") + println("############### pc closed") } if err == nil { panic("nil error") @@ -2128,6 +1939,7 @@ func (pc *persistConn) closeLocked(err error) { close(pc.closech) close(pc.writeLoopDone) pc.client.Free() + pc.chunkAsync.Close(nil) } } pc.mutateHeaderFunc = nil @@ -2256,13 +2068,11 @@ func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { // the 1st response byte from the server. return true } - if err == errServerClosedIdle { - // The server replied with io.EOF while we were trying to - // read the response. Probably an unfortunately keep-alive - // timeout, just as the client was writing a request. - return true - } - return false // conservatively + // The server replied with io.EOF while we were trying to + // read the response. Probably an unfortunately keep-alive + // timeout, just as the client was writing a request. + // conservatively return false. + return err == errServerClosedIdle } // closeConnIfStillIdle closes the connection if it's still sitting idle. @@ -2300,6 +2110,45 @@ func (pc *persistConn) readLoopPeekFailLocked(resp *hyper.Response, err error) { pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %w", err)) } +// setExtraHeaders Set extra headers, such as Accept-Encoding, Connection(Keep-Alive). +func (pc *persistConn) setExtraHeaders(req *transportRequest) bool { + // Ask for a compressed version if the caller didn't set their + // own value for Accept-Encoding. We only attempt to + // uncompress the gzip stream if we were the layer that + // requested it. + requestedGzip := false + // TODO(spongehah) gzip(pc.roundTrip) + //if !pc.t.DisableCompression && + // req.Header.Get("Accept-Encoding") == "" && + // req.Header.Get("Range") == "" && + // req.Method != "HEAD" { + // // Request gzip only, not deflate. Deflate is ambiguous and + // // not as universally supported anyway. + // // See: https://zlib.net/zlib_faq.html#faq39 + // // + // // Note that we don't request this for HEAD requests, + // // due to a bug in nginx: + // // https://trac.nginx.org/nginx/ticket/358 + // // https://golang.org/issue/5522 + // // + // // We don't request gzip if the request is for a range, since + // // auto-decoding a portion of a gzipped document will just fail + // // anyway. See https://golang.org/issue/8923 + // requestedGzip = true + // req.extraHeaders().Set("Accept-Encoding", "gzip") + //} + + // The 100-continue operation in Hyper is handled in the newHyperRequest function. + + // Keep-Alive + if pc.t.DisableKeepAlives && + !req.wantsClose() && + !isProtocolSwitchHeader(req.Header) { + req.extraHeaders().Set("Connection", "close") + } + return requestedGzip +} + func is408Message(resp *hyper.Response) bool { httpVersion := int(resp.Version()) if httpVersion != 10 && httpVersion != 11 { @@ -2435,7 +2284,6 @@ func (w *wantConn) cancel(t *Transport, err error) { w.err = err w.mu.Unlock() - // TODO(spongehah) ConnPool(w.cancel) if pc != nil { t.putOrCloseIdleConn(pc) } @@ -2534,110 +2382,6 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) { } } -// bodyEOFSignal is used by the HTTP/1 transport when reading response -// bodies to make sure we see the end of a response body before -// proceeding and reading on the connection again. -// -// It wraps a ReadCloser but runs fn (if non-nil) at most -// once, right before its final (error-producing) Read or Close call -// returns. fn should return the new error to return from Read or Close. -// -// If earlyCloseFn is non-nil and Close is called before io.EOF is -// seen, earlyCloseFn is called instead of fn, and its return value is -// the return value from Close. -type bodyEOFSignal struct { - body io.ReadCloser - mu sync.Mutex // guards following 4 fields - closed bool // whether Close has been called - rerr error // sticky Read error - fn func(error) error // err will be nil on Read io.EOF - earlyCloseFn func() error // optional alt Close func used if io.EOF not seen -} - -var errReadOnClosedResBody = errors.New("http: read on closed response body") - -func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { - es.mu.Lock() - closed, rerr := es.closed, es.rerr - es.mu.Unlock() - if closed { - return 0, errReadOnClosedResBody - } - if rerr != nil { - return 0, rerr - } - - n, err = es.body.Read(p) - if err != nil { - es.mu.Lock() - defer es.mu.Unlock() - if es.rerr == nil { - es.rerr = err - } - err = es.condfn(err) - } - return -} - -func (es *bodyEOFSignal) Close() error { - es.mu.Lock() - defer es.mu.Unlock() - if es.closed { - return nil - } - es.closed = true - if es.earlyCloseFn != nil && es.rerr != io.EOF { - return es.earlyCloseFn() - } - err := es.body.Close() - return es.condfn(err) -} - -// caller must hold es.mu. -func (es *bodyEOFSignal) condfn(err error) error { - if es.fn == nil { - return err - } - err = es.fn(err) - es.fn = nil - return err -} - -// gzipReader wraps a response body so it can lazily -// call gzip.NewReader on the first call to Read -type gzipReader struct { - _ incomparable - body *bodyEOFSignal // underlying HTTP/1 response body framing - zr *gzip.Reader // lazily-initialized gzip reader - zerr error // any error from gzip.NewReader; sticky -} - -func (gz *gzipReader) Read(p []byte) (n int, err error) { - if gz.zr == nil { - if gz.zerr == nil { - gz.zr, gz.zerr = gzip.NewReader(gz.body) - } - if gz.zerr != nil { - return 0, gz.zerr - } - } - - gz.body.mu.Lock() - if gz.body.closed { - err = errReadOnClosedResBody - } - gz.body.mu.Unlock() - - if err != nil { - return 0, err - } - return gz.zr.Read(p) -} - -func (gz *gzipReader) Close() error { - return gz.body.Close() -} - type connLRU struct { ll *list.List // list.Element.Value type is of *persistConn m map[*persistConn]*list.Element