From a02e126411b7dcf4d701066d749f98d556345138 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Mon, 29 Jul 2024 16:25:23 +0800 Subject: [PATCH 01/21] WIP(x/http/get): Implementing get request using native socket --- go.mod | 2 +- go.sum | 2 + x/http/_demo/get/get.go | 17 ++ x/http/_demo/test.go | 141 +++++++++++++++ x/http/client.go | 388 ++++++++++++++++++++++++++++++++++++++++ x/http/header.go | 34 ++++ x/http/hyper-go.go | 12 ++ x/http/response.go | 87 +++++++++ 8 files changed, 682 insertions(+), 1 deletion(-) create mode 100644 x/http/_demo/get/get.go create mode 100644 x/http/_demo/test.go create mode 100644 x/http/client.go create mode 100644 x/http/header.go create mode 100644 x/http/hyper-go.go create mode 100644 x/http/response.go diff --git a/go.mod b/go.mod index 7e12cef..ff4ef8f 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/goplus/llgoexamples go 1.20 -require github.com/goplus/llgo v0.9.3-0.20240726020431-98d075728f2b +require github.com/goplus/llgo v0.9.4-0.20240729010130-b3b4f55c68f0 diff --git a/go.sum b/go.sum index fdc017f..8799390 100644 --- a/go.sum +++ b/go.sum @@ -6,3 +6,5 @@ github.com/goplus/llgo v0.9.1-0.20240712060421-858d38d314a3 h1:2fZ2zQ8S58KvOsJTx github.com/goplus/llgo v0.9.1-0.20240712060421-858d38d314a3/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= github.com/goplus/llgo v0.9.3-0.20240726020431-98d075728f2b h1:z9FUoeAALL5ytBhhGhE1dXm4+L1Q2eMUTcfiqLAZgf8= github.com/goplus/llgo v0.9.3-0.20240726020431-98d075728f2b/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= +github.com/goplus/llgo v0.9.4-0.20240729010130-b3b4f55c68f0 h1:02gSx3Oj3cLlBMed+9IWBUGHThEZMnCNiR67yaQbpqo= +github.com/goplus/llgo v0.9.4-0.20240729010130-b3b4f55c68f0/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= diff --git a/x/http/_demo/get/get.go b/x/http/_demo/get/get.go new file mode 100644 index 0000000..e7591a6 --- /dev/null +++ b/x/http/_demo/get/get.go @@ -0,0 +1,17 @@ +package main + +import ( + "fmt" + + "github.com/goplus/llgoexamples/x/http" +) + +func main() { + // 使用 http.Get 发送 GET 请求 + resp := http.Get("https://www.baidu.com/") + fmt.Println(resp.Status) + fmt.Println(resp.StatusCode) + resp.PrintHeaders() + fmt.Println() + resp.PrintBody() +} diff --git a/x/http/_demo/test.go b/x/http/_demo/test.go new file mode 100644 index 0000000..2ddd3ea --- /dev/null +++ b/x/http/_demo/test.go @@ -0,0 +1,141 @@ +package main + +import ( + "fmt" + "io" +) + +func main() { + + // 假设你有一个 []byte 数组 + data := []byte("This is some data that needs to be stored in Body.") + + // 创建一个 io.Pipe + pr, pw := io.Pipe() + + // 启动一个 goroutine 将数据写入 io.Pipe 的写入端 + go func() { + defer pw.Close() // 确保写入完成后关闭写入端 + + if _, err := pw.Write(data); err != nil { + fmt.Println("Error writing to pipe:", err) + return + } + }() + + // 读取 Body 中的数据进行验证 + readData, err := io.ReadAll(pr) + if err != nil { + fmt.Println("Error reading from Body:", err) + return + } + + // 输出 Body 中的数据 + fmt.Println("Body content:", string(readData)) + // + //http.Get() + + //r, w := io.Pipe() + // + //go func() { + // fmt.Fprint(w, "some io.Reader stream to be read\n") + // w.Close() + //}() + // + //if _, err := io.Copy(os.Stdout, r); err != nil { + // log.Fatal(err) + //} + + // 使用 http.Get 发送 GET 请求 + //resp, err := http.Get("https://www.baidu.com/") + //if err != nil { + // fmt.Println("Error:", err) + // return + //} + //defer resp.Body.Close() + // + //body, err := io.ReadAll(resp.Body) + //if err != nil { + // fmt.Println("Error reading response:", err) + // return + //} + //fmt.Println("GET Response:\n", string(body)) + + //rawURL := "http://example.com:8080/path/to/resource?query=123#fragment" + //parsedURL, err := url.Parse(rawURL) + //if err != nil { + // fmt.Println("Error parsing URL:", err) + // return + //} + // + //hostname := parsedURL.Hostname() + //port := parsedURL.Port() + // + //uri := parsedURL.RequestURI() + // + //fmt.Println("Hostname:", hostname) + //fmt.Println("Port:", port) + //fmt.Println("URI:", uri) + + //// 使用 http.Post 发送 POST 请求上传文件 + //file, err := os.Open("path/to/your/file.jpg") + //if err != nil { + // fmt.Println("Error opening file:", err) + // return + //} + //defer file.Close() + // + //var buf bytes.Buffer + //writer := multipart.NewWriter(&buf) + //_, err = writer.CreateFormFile("file", "file.jpg") + //if err != nil { + // fmt.Println("Error creating form file:", err) + // return + //} + // + //_, err = io.ReadAll(file) + //if err != nil { + // fmt.Println("Error reading file:", err) + // return + //} + // + //err = writer.Close() + //if err != nil { + // fmt.Println("Error closing writer:", err) + // return + //} + // + //resp, err = http.Post("https://www.baidu.com/upload", writer.FormDataContentType(), &buf) + //if err != nil { + // fmt.Println("Error:", err) + // return + //} + //defer resp.Body.Close() + // + //body, err = io.ReadAll(resp.Body) + //if err != nil { + // fmt.Println("Error reading response:", err) + // return + //} + //fmt.Println("POST Response:\n", string(body)) + // + //// 使用 http.PostForm 发送表单数据 + //formData := url.Values{ + // "key": {"Value"}, + // "id": {"123"}, + //} + // + //resp, err = http.PostForm("https://www.baidu.com/form", formData) + //if err != nil { + // fmt.Println("Error:", err) + // return + //} + //defer resp.Body.Close() + // + //body, err = io.ReadAll(resp.Body) + //if err != nil { + // fmt.Println("Error reading response:", err) + // return + //} + //fmt.Println("POST Form Response:\n", string(body)) +} diff --git a/x/http/client.go b/x/http/client.go new file mode 100644 index 0000000..32cf658 --- /dev/null +++ b/x/http/client.go @@ -0,0 +1,388 @@ +package http + +import ( + "fmt" + "strconv" + "strings" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgo/c/net" + "github.com/goplus/llgo/c/os" + "github.com/goplus/llgo/c/sys" + "github.com/goplus/llgo/c/syscall" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type ConnData struct { + Fd c.Int + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker +} + +type RequestConfig struct { + ReqMethod string + ReqHost string + ReqPort string + ReqUri string + ReqHeaders map[string]string + ReqHTTPVersion hyper.HTTPVersion + TimeoutSec int64 + TimeoutUsec int32 + //ReqBody + //ReqURIParts +} + +func Get(url string) *Response { + host, port, uri := parseURL(url) + req := hyper.NewRequest() + + // Prepare the request + // Set the request method and uri + if req.SetMethod((*uint8)(&[]byte("GET")[0]), c.Strlen(c.Str("GET"))) != hyper.OK { + panic(fmt.Sprintf("error setting method %s\n", "GET")) + } + if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { + panic(fmt.Sprintf("error setting uri %s\n", uri)) + } + + // Set the request headers + reqHeaders := req.Headers() + if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { + panic("error setting headers\n") + } + + //var response RequestResponse + + fd := ConnectTo(host, port) + + connData := NewConnData(fd) + + // Hookup the IO + io := NewIoWithConnReadWrite(connData) + + // We need an executor generally to poll futures + exec := hyper.NewExecutor() + + // Prepare client options + opts := hyper.NewClientConnOptions() + opts.Exec(exec) + + handshakeTask := hyper.Handshake(io, opts) + SetUserData(handshakeTask, hyper.ExampleHandshake) + + // Let's wait for the handshake to finish... + exec.Push(handshakeTask) + + var fdsRead, fdsWrite, fdsExcep syscall.FdSet + var err *hyper.Error + var response Response + + // The polling state machine! + for { + // Poll all ready tasks and act on them... + for { + task := exec.Poll() + + if task == nil { + break + } + + switch (hyper.ExampleId)(uintptr(task.Userdata())) { + case hyper.ExampleHandshake: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("handshake error!\n")) + err = (*hyper.Error)(task.Value()) + Fail(err) + } + if task.Type() != hyper.TaskClientConn { + c.Printf(c.Str("unexpected task type\n")) + Fail(err) + } + + client := (*hyper.ClientConn)(task.Value()) + task.Free() + + // Send it! + sendTask := client.Send(req) + SetUserData(sendTask, hyper.ExampleSend) + sendRes := exec.Push(sendTask) + if sendRes != hyper.OK { + panic("error send\n") + } + + // For this example, no longer need the client + client.Free() + + break + case hyper.ExampleSend: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("send error!\n")) + err = (*hyper.Error)(task.Value()) + Fail(err) + } + if task.Type() != hyper.TaskResponse { + c.Printf(c.Str("unexpected task type\n")) + Fail(err) + } + + // Take the results + resp := (*hyper.Response)(task.Value()) + task.Free() + + rp := resp.ReasonPhrase() + rpLen := resp.ReasonPhraseLen() + + response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) + response.StatusCode = int(resp.Status()) + + headers := resp.Headers() + headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) + respBody := resp.Body() + + foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) + + SetUserData(foreachTask, hyper.ExampleRespBody) + exec.Push(foreachTask) + + // No longer need the response + resp.Free() + + break + case hyper.ExampleRespBody: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("body error!\n")) + err = (*hyper.Error)(task.Value()) + Fail(err) + } + if task.Type() != hyper.TaskEmpty { + c.Printf(c.Str("unexpected task type\n")) + Fail(err) + } + + // Cleaning up before exiting + task.Free() + exec.Free() + FreeConnData(connData) + + if response.respBodyWriter != nil { + defer response.respBodyWriter.Close() + } + + return &response + case hyper.ExampleNotSet: + // A background task for hyper_client completed... + task.Free() + break + } + } + + // All futures are pending on IO work, so select on the fds. + + sys.FD_ZERO(&fdsRead) + sys.FD_ZERO(&fdsWrite) + sys.FD_ZERO(&fdsExcep) + + if connData.ReadWaker != nil { + sys.FD_SET(connData.Fd, &fdsRead) + } + if connData.WriteWaker != nil { + sys.FD_SET(connData.Fd, &fdsWrite) + } + + // Set the default request timeout + var tv syscall.Timeval + tv.Sec = 10 + + selRet := sys.Select(connData.Fd+1, &fdsRead, &fdsWrite, &fdsExcep, &tv) + if selRet < 0 { + panic("select() error\n") + } else if selRet == 0 { + panic("select() timeout\n") + } + + if sys.FD_ISSET(connData.Fd, &fdsRead) != 0 { + connData.ReadWaker.Wake() + connData.ReadWaker = nil + } + + if sys.FD_ISSET(connData.Fd, &fdsWrite) != 0 { + connData.WriteWaker.Wake() + connData.WriteWaker = nil + } + } +} + +// ConnectTo connects to a host and port +func ConnectTo(host string, port string) c.Int { + var hints net.AddrInfo + hints.Family = net.AF_UNSPEC + hints.SockType = net.SOCK_STREAM + + var result, rp *net.AddrInfo + + if net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &result) != 0 { + panic(fmt.Sprintf("dns failed for %s\n", host)) + } + + var sfd c.Int + for rp = result; rp != nil; rp = rp.Next { + sfd = net.Socket(rp.Family, rp.SockType, rp.Protocol) + if sfd == -1 { + continue + } + if net.Connect(sfd, rp.Addr, rp.AddrLen) != -1 { + break + } + os.Close(sfd) + } + + net.Freeaddrinfo(result) + + // no address succeeded + if rp == nil || sfd < 0 { + panic(fmt.Sprintf("connect failed for %s\n", host)) + } + + if os.Fcntl(sfd, os.F_SETFL, os.O_NONBLOCK) != 0 { + panic("failed to set net to non-blocking\n") + } + return sfd +} + +// ReadCallBack is the callback for reading from a socket +func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + conn := (*ConnData)(userdata) + + ret := os.Read(conn.Fd, c.Pointer(buf), bufLen) + + if ret >= 0 { + return uintptr(ret) + } + + if os.Errno != os.EAGAIN { + c.Perror(c.Str("[read callback fail]")) + // kaboom + return hyper.IoError + } + + // would block, register interest + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + } + conn.ReadWaker = ctx.Waker() + return hyper.IoPending +} + +// WriteCallBack is the callback for writing to a socket +func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + conn := (*ConnData)(userdata) + ret := os.Write(conn.Fd, c.Pointer(buf), bufLen) + + if int(ret) >= 0 { + return uintptr(ret) + } + + if os.Errno != os.EAGAIN { + c.Perror(c.Str("[write callback fail]")) + // kaboom + return hyper.IoError + } + + // would block, register interest + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + } + conn.WriteWaker = ctx.Waker() + return hyper.IoPending +} + +// FreeConnData frees the connection data +func FreeConnData(conn *ConnData) { + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + conn.ReadWaker = nil + } + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + conn.WriteWaker = nil + } +} + +// Fail prints the error details and panics +func Fail(err *hyper.Error) { + if err != nil { + c.Printf(c.Str("error code: %d\n"), err.Code()) + // grab the error details + var errBuf [256]c.Char + errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) + + c.Printf(c.Str("details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) + c.Printf(c.Str("details: ")) + for i := 0; i < int(errLen); i++ { + c.Printf(c.Str("%c"), errBuf[i]) + } + c.Printf(c.Str("\n")) + + // clean up the error + err.Free() + panic("request failed\n") + } + return +} + +// NewConnData creates a new connection data +func NewConnData(fd c.Int) *ConnData { + return &ConnData{Fd: fd, ReadWaker: nil, WriteWaker: nil} +} + +// NewIoWithConnReadWrite creates a new IO with read and write callbacks +func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { + io := hyper.NewIo() + io.SetUserdata(c.Pointer(connData)) + io.SetRead(ReadCallBack) + io.SetWrite(WriteCallBack) + return io +} + +// parseURL Parse the URL and extract the host name, port number, and URI +func parseURL(rawURL string) (hostname, port, uri string) { + // 找到 "://" 的位置,以分隔协议和主机名 + schemeEnd := strings.Index(rawURL, "://") + if schemeEnd != -1 { + //scheme = rawURL[:schemeEnd] + rawURL = rawURL[schemeEnd+3:] + } else { + //scheme = "http" // 默认协议为 http + } + + // 找到第一个 "/" 的位置,以分隔主机名和路径 + pathStart := strings.Index(rawURL, "/") + if pathStart != -1 { + uri = rawURL[pathStart:] + rawURL = rawURL[:pathStart] + } else { + uri = "/" + } + + // 找到 ":" 的位置,以分隔主机名和端口号 + portStart := strings.LastIndex(rawURL, ":") + if portStart != -1 { + hostname = rawURL[:portStart] + port = rawURL[portStart+1:] + } else { + hostname = rawURL + port = "" // 未指定端口号 + } + + // 如果未指定端口号,根据协议设置默认端口号 + if port == "" { + //if scheme == "https" { + // port = "443" + //} else { + // port = "80" + //} + port = "80" + } + + return +} diff --git a/x/http/header.go b/x/http/header.go new file mode 100644 index 0000000..ea313a2 --- /dev/null +++ b/x/http/header.go @@ -0,0 +1,34 @@ +package http + +import ( + "fmt" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type Header map[string][]string + +// AppendToResponseHeader (HeadersForEachCallback) prints each header to the console +func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { + resp := (*Response)(userdata) + nameStr := string((*[1 << 30]byte)(c.Pointer(name))[:nameLen:nameLen]) + valueStr := string((*[1 << 30]byte)(c.Pointer(value))[:valueLen:valueLen]) + + if resp.Header == nil { + resp.Header = make(map[string][]string) + } + resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) + //c.Printf(c.Str("%.*s: %.*s\n"), int(nameLen), name, int(valueLen), value) + return hyper.IterContinue +} + +func (resp *Response) PrintHeaders() { + for key, values := range resp.Header { + fmt.Printf("%s: ", key) + for _, value := range values { + fmt.Printf(value + "; ") + } + fmt.Printf("\n") + } +} diff --git a/x/http/hyper-go.go b/x/http/hyper-go.go new file mode 100644 index 0000000..a1db081 --- /dev/null +++ b/x/http/hyper-go.go @@ -0,0 +1,12 @@ +package http + +import ( + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" +) + +// SetUserData Set the user data for the task +func SetUserData(task *hyper.Task, userData hyper.ExampleId) { + var data = userData + task.SetUserdata(c.Pointer(uintptr(data))) +} diff --git a/x/http/response.go b/x/http/response.go new file mode 100644 index 0000000..9263461 --- /dev/null +++ b/x/http/response.go @@ -0,0 +1,87 @@ +package http + +import ( + "fmt" + "io" + "unsafe" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type Response struct { + Status string + StatusCode int + Header Header + ResponseBody io.ReadCloser + respBodyWriter *io.PipeWriter + ResponseBodyLen int64 +} + +// AppendToResponseBody (BodyForEachCallback) appends the body to the response +func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { + resp := (*Response)(userdata) + len := chunk.Len() + buf := unsafe.Slice((*byte)(chunk.Bytes()), len) + + if resp.ResponseBody == nil { + var reader *io.PipeReader + reader, resp.respBodyWriter = io.Pipe() + resp.ResponseBody = io.ReadCloser(reader) + } + resp.ResponseBodyLen += int64(len) + var err error + go func() { + _, err = resp.respBodyWriter.Write(buf) + }() + if err != nil { + fmt.Printf("Failed to write response body: %v\n", err) + return hyper.IterBreak + } + return hyper.IterContinue +} + +func (resp *Response) PrintBody() { + var buffer = make([]byte, resp.ResponseBodyLen) + for { + n, err := resp.ResponseBody.Read(buffer) + if err == io.EOF { + fmt.Printf("\n") + break + } + if err != nil { + fmt.Println("Error reading from pipe:", err) + break + } + fmt.Printf("%s", string(buffer[:n])) + } +} + +//// AppendToResponseBody (BodyForEachCallback) appends the body to the response +//func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { +// resp := (*Response)(userdata) +// buf := chunk.Bytes() +// len := chunk.Len() +// responseBody := (*uint8)(c.Malloc(resp.ResponseBodyLen + len)) +// if responseBody == nil { +// c.Fprintf(c.Stderr, c.Str("Failed to allocate memory for response body\n")) +// return hyper.IterBreak +// } +// +// // Copy the existing response body to the new buffer +// if resp.ResponseBody != nil { +// c.Memcpy(c.Pointer(responseBody), c.Pointer(resp.ResponseBody), resp.ResponseBodyLen) +// c.Free(c.Pointer(resp.ResponseBody)) +// } +// +// // Append the new data +// c.Memcpy(c.Pointer(uintptr(c.Pointer(responseBody))+resp.ResponseBodyLen), c.Pointer(buf), len) +// resp.ResponseBody = responseBody +// resp.ResponseBodyLen += len +// return hyper.IterContinue +//} + +//func (resp *Response) PrintBody() { +// //c.Printf(c.Str("%.*s\n"), c.Int(resp.ResponseBodyLen), resp.ResponseBody) +// fmt.Println(string((*[1 << 30]byte)(c.Pointer(resp.ResponseBody))[:resp.ResponseBodyLen:resp.ResponseBodyLen])) +//} From dc55abcfc2205d17eaf81867a8f2fb6f7158c9f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Tue, 30 Jul 2024 16:16:20 +0800 Subject: [PATCH 02/21] feat(x/http/get): Using libuv to speed up http.Get() --- go.mod | 2 +- go.sum | 12 +- x/http/_demo/get/get.go | 11 +- x/http/_demo/test.go | 141 -------------- x/http/client.go | 393 +++++++++++++++++++++++++--------------- x/http/header.go | 4 +- x/http/hyper-go.go | 12 -- x/http/response.go | 117 ++++++------ 8 files changed, 325 insertions(+), 367 deletions(-) delete mode 100644 x/http/_demo/test.go delete mode 100644 x/http/hyper-go.go diff --git a/go.mod b/go.mod index ff4ef8f..39080a4 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/goplus/llgoexamples go 1.20 -require github.com/goplus/llgo v0.9.4-0.20240729010130-b3b4f55c68f0 +require github.com/goplus/llgo v0.9.5-0.20240731053840-36072584d0be diff --git a/go.sum b/go.sum index 8799390..e2c5d17 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,2 @@ -github.com/goplus/llgo v0.9.0 h1:yaJzQperGUafEaHc9VlVQVskIngacoTNweEXY0GRi0Q= -github.com/goplus/llgo v0.9.0/go.mod h1:M3UwiYdPZFyx7m2J0+6Ti1dYVA3uOO1WvSBocuE8N7M= -github.com/goplus/llgo v0.9.1-0.20240709104849-d6a38a567fda h1:UIPwlgzCb8dV/7WFMyprhZuq8CSLAQIqwFpH5AhrNOM= -github.com/goplus/llgo v0.9.1-0.20240709104849-d6a38a567fda/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= -github.com/goplus/llgo v0.9.1-0.20240712060421-858d38d314a3 h1:2fZ2zQ8S58KvOsJTx6s6MHoi6n1K4sqQwIbTauMrgEE= -github.com/goplus/llgo v0.9.1-0.20240712060421-858d38d314a3/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= -github.com/goplus/llgo v0.9.3-0.20240726020431-98d075728f2b h1:z9FUoeAALL5ytBhhGhE1dXm4+L1Q2eMUTcfiqLAZgf8= -github.com/goplus/llgo v0.9.3-0.20240726020431-98d075728f2b/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= -github.com/goplus/llgo v0.9.4-0.20240729010130-b3b4f55c68f0 h1:02gSx3Oj3cLlBMed+9IWBUGHThEZMnCNiR67yaQbpqo= -github.com/goplus/llgo v0.9.4-0.20240729010130-b3b4f55c68f0/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= +github.com/goplus/llgo v0.9.5-0.20240731053840-36072584d0be h1:FTALxA3ivIeVRAO93e1hCSCLaPbjKn+RZx40p5lx8KE= +github.com/goplus/llgo v0.9.5-0.20240731053840-36072584d0be/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= diff --git a/x/http/_demo/get/get.go b/x/http/_demo/get/get.go index e7591a6..09f32a0 100644 --- a/x/http/_demo/get/get.go +++ b/x/http/_demo/get/get.go @@ -8,10 +8,17 @@ import ( func main() { // 使用 http.Get 发送 GET 请求 - resp := http.Get("https://www.baidu.com/") + resp, err := http.Get("https://www.baidu.com/") + if err != nil { + fmt.Println(err) + return + } fmt.Println(resp.Status) fmt.Println(resp.StatusCode) resp.PrintHeaders() fmt.Println() - resp.PrintBody() + resp.PrintBody2() + + resp.PrintBody1() + defer resp.Content.Close() } diff --git a/x/http/_demo/test.go b/x/http/_demo/test.go deleted file mode 100644 index 2ddd3ea..0000000 --- a/x/http/_demo/test.go +++ /dev/null @@ -1,141 +0,0 @@ -package main - -import ( - "fmt" - "io" -) - -func main() { - - // 假设你有一个 []byte 数组 - data := []byte("This is some data that needs to be stored in Body.") - - // 创建一个 io.Pipe - pr, pw := io.Pipe() - - // 启动一个 goroutine 将数据写入 io.Pipe 的写入端 - go func() { - defer pw.Close() // 确保写入完成后关闭写入端 - - if _, err := pw.Write(data); err != nil { - fmt.Println("Error writing to pipe:", err) - return - } - }() - - // 读取 Body 中的数据进行验证 - readData, err := io.ReadAll(pr) - if err != nil { - fmt.Println("Error reading from Body:", err) - return - } - - // 输出 Body 中的数据 - fmt.Println("Body content:", string(readData)) - // - //http.Get() - - //r, w := io.Pipe() - // - //go func() { - // fmt.Fprint(w, "some io.Reader stream to be read\n") - // w.Close() - //}() - // - //if _, err := io.Copy(os.Stdout, r); err != nil { - // log.Fatal(err) - //} - - // 使用 http.Get 发送 GET 请求 - //resp, err := http.Get("https://www.baidu.com/") - //if err != nil { - // fmt.Println("Error:", err) - // return - //} - //defer resp.Body.Close() - // - //body, err := io.ReadAll(resp.Body) - //if err != nil { - // fmt.Println("Error reading response:", err) - // return - //} - //fmt.Println("GET Response:\n", string(body)) - - //rawURL := "http://example.com:8080/path/to/resource?query=123#fragment" - //parsedURL, err := url.Parse(rawURL) - //if err != nil { - // fmt.Println("Error parsing URL:", err) - // return - //} - // - //hostname := parsedURL.Hostname() - //port := parsedURL.Port() - // - //uri := parsedURL.RequestURI() - // - //fmt.Println("Hostname:", hostname) - //fmt.Println("Port:", port) - //fmt.Println("URI:", uri) - - //// 使用 http.Post 发送 POST 请求上传文件 - //file, err := os.Open("path/to/your/file.jpg") - //if err != nil { - // fmt.Println("Error opening file:", err) - // return - //} - //defer file.Close() - // - //var buf bytes.Buffer - //writer := multipart.NewWriter(&buf) - //_, err = writer.CreateFormFile("file", "file.jpg") - //if err != nil { - // fmt.Println("Error creating form file:", err) - // return - //} - // - //_, err = io.ReadAll(file) - //if err != nil { - // fmt.Println("Error reading file:", err) - // return - //} - // - //err = writer.Close() - //if err != nil { - // fmt.Println("Error closing writer:", err) - // return - //} - // - //resp, err = http.Post("https://www.baidu.com/upload", writer.FormDataContentType(), &buf) - //if err != nil { - // fmt.Println("Error:", err) - // return - //} - //defer resp.Body.Close() - // - //body, err = io.ReadAll(resp.Body) - //if err != nil { - // fmt.Println("Error reading response:", err) - // return - //} - //fmt.Println("POST Response:\n", string(body)) - // - //// 使用 http.PostForm 发送表单数据 - //formData := url.Values{ - // "key": {"Value"}, - // "id": {"123"}, - //} - // - //resp, err = http.PostForm("https://www.baidu.com/form", formData) - //if err != nil { - // fmt.Println("Error:", err) - // return - //} - //defer resp.Body.Close() - // - //body, err = io.ReadAll(resp.Body) - //if err != nil { - // fmt.Println("Error reading response:", err) - // return - //} - //fmt.Println("POST Form Response:\n", string(body)) -} diff --git a/x/http/client.go b/x/http/client.go index 32cf658..d173574 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -2,63 +2,127 @@ package http import ( "fmt" + io2 "io" "strconv" "strings" + "unsafe" "github.com/goplus/llgo/c" + "github.com/goplus/llgo/c/libuv" "github.com/goplus/llgo/c/net" - "github.com/goplus/llgo/c/os" - "github.com/goplus/llgo/c/sys" "github.com/goplus/llgo/c/syscall" "github.com/goplus/llgoexamples/rust/hyper" ) type ConnData struct { - Fd c.Int - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker + TcpHandle libuv.Tcp + ConnectReq libuv.Connect + ReadBuf libuv.Buf + ReadBufFilled uintptr + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker } -type RequestConfig struct { - ReqMethod string - ReqHost string - ReqPort string - ReqUri string - ReqHeaders map[string]string - ReqHTTPVersion hyper.HTTPVersion - TimeoutSec int64 - TimeoutUsec int32 - //ReqBody - //ReqURIParts +type Client struct { + Transport RoundTripper } -func Get(url string) *Response { - host, port, uri := parseURL(url) - req := hyper.NewRequest() +var DefaultClient = &Client{} + +type RoundTripper interface { + RoundTrip(*hyper.Request) (*Response, error) +} + +func (c *Client) transport() RoundTripper { + if c.Transport != nil { + return c.Transport + } + return DefaultTransport +} + +func Get2(url string) (*Response, error) { + return DefaultClient.Get(url) +} +func (c *Client) Get(url string) (*Response, error) { + req, err := NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + +func (c *Client) Do(req *hyper.Request) (*Response, error) { + return c.do(req) +} + +func (c *Client) do(req *hyper.Request) (*Response, error) { + return c.send(req, nil) +} + +func (c *Client) send(req *hyper.Request, deadline any) (*Response, error) { + return send(req, c.transport(), deadline) +} + +func send(req *hyper.Request, rt RoundTripper, deadline any) (resp *Response, err error) { + return rt.RoundTrip(req) +} + +func NewRequest(method, url string, body io2.Reader) (*hyper.Request, error) { + host, _, uri := parseURL(url) // Prepare the request + req := hyper.NewRequest() // Set the request method and uri - if req.SetMethod((*uint8)(&[]byte("GET")[0]), c.Strlen(c.Str("GET"))) != hyper.OK { - panic(fmt.Sprintf("error setting method %s\n", "GET")) + if req.SetMethod((*uint8)(&[]byte(method)[0]), c.Strlen(c.AllocaCStr(method))) != hyper.OK { + return nil, fmt.Errorf("error setting method %s\n", method) } if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { - panic(fmt.Sprintf("error setting uri %s\n", uri)) + return nil, fmt.Errorf("error setting uri %s\n", uri) } // Set the request headers reqHeaders := req.Headers() if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { - panic("error setting headers\n") + return nil, fmt.Errorf("error setting headers\n") + } + return req, nil +} + +func Get(url string) (_ *Response, err error) { + host, port, uri := parseURL(url) + + loop := libuv.DefaultLoop() + conn := (*ConnData)(c.Malloc(unsafe.Sizeof(ConnData{}))) + if conn == nil { + return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") } - //var response RequestResponse + libuv.InitTcp(loop, &conn.TcpHandle) + //conn.TcpHandle.Data = c.Pointer(conn) + (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) + + var hints net.AddrInfo + c.Memset(c.Pointer(&hints), 0, unsafe.Sizeof(hints)) + hints.Family = syscall.AF_UNSPEC + hints.SockType = syscall.SOCK_STREAM + + var res *net.AddrInfo + status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) + if status != 0 { + return nil, fmt.Errorf("getaddrinfo error\n") + } - fd := ConnectTo(host, port) + //conn.ConnectReq.Data = c.Pointer(conn) + (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) + status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, OnConnect) + if status != 0 { + return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) + } - connData := NewConnData(fd) + net.Freeaddrinfo(res) // Hookup the IO - io := NewIoWithConnReadWrite(connData) + io := NewIoWithConnReadWrite(conn) // We need an executor generally to poll futures exec := hyper.NewExecutor() @@ -73,8 +137,7 @@ func Get(url string) *Response { // Let's wait for the handshake to finish... exec.Push(handshakeTask) - var fdsRead, fdsWrite, fdsExcep syscall.FdSet - var err *hyper.Error + var hyperErr *hyper.Error var response Response // The polling state machine! @@ -82,7 +145,6 @@ func Get(url string) *Response { // Poll all ready tasks and act on them... for { task := exec.Poll() - if task == nil { break } @@ -91,23 +153,41 @@ func Get(url string) *Response { case hyper.ExampleHandshake: if task.Type() == hyper.TaskError { c.Printf(c.Str("handshake error!\n")) - err = (*hyper.Error)(task.Value()) - Fail(err) + hyperErr = (*hyper.Error)(task.Value()) + err = Fail(hyperErr) + return nil, err } if task.Type() != hyper.TaskClientConn { c.Printf(c.Str("unexpected task type\n")) - Fail(err) + err = Fail(hyperErr) + return nil, err } client := (*hyper.ClientConn)(task.Value()) task.Free() + // Prepare the request + req := hyper.NewRequest() + // Set the request method and uri + if req.SetMethod((*uint8)(&[]byte("GET")[0]), c.Strlen(c.Str("GET"))) != hyper.OK { + return nil, fmt.Errorf("error setting method %s\n", "GET") + } + if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { + return nil, fmt.Errorf("error setting uri %s\n", uri) + } + + // Set the request headers + reqHeaders := req.Headers() + if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { + return nil, fmt.Errorf("error setting headers\n") + } + // Send it! sendTask := client.Send(req) SetUserData(sendTask, hyper.ExampleSend) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { - panic("error send\n") + return nil, fmt.Errorf("error send\n") } // For this example, no longer need the client @@ -117,12 +197,14 @@ func Get(url string) *Response { case hyper.ExampleSend: if task.Type() == hyper.TaskError { c.Printf(c.Str("send error!\n")) - err = (*hyper.Error)(task.Value()) - Fail(err) + hyperErr = (*hyper.Error)(task.Value()) + err = Fail(hyperErr) + return nil, err } if task.Type() != hyper.TaskResponse { c.Printf(c.Str("unexpected task type\n")) - Fail(err) + err = Fail(hyperErr) + return nil, err } // Take the results @@ -139,133 +221,127 @@ func Get(url string) *Response { headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) respBody := resp.Body() + response.Body, response.respBodyWriter = io2.Pipe() + + /*go func() { + fmt.Println("writing...") + for { + fmt.Println("writing for...") + dataTask := respBody.Data() + exec.Push(dataTask) + dataTask = exec.Poll() + if dataTask.Type() == hyper.TaskBuf { + buf := (*hyper.Buf)(dataTask.Value()) + len := buf.Len() + bytes := unsafe.Slice((*byte)(buf.Bytes()), len) + _, err := response.respBodyWriter.Write(bytes) + if err != nil { + fmt.Printf("Failed to write response body: %v\n", err) + break + } + dataTask.Free() + } else if dataTask.Type() == hyper.TaskEmpty { + fmt.Println("writing empty") + dataTask.Free() + break + } + } + fmt.Println("end writing") + defer response.respBodyWriter.Close() + }()*/ + foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) SetUserData(foreachTask, hyper.ExampleRespBody) exec.Push(foreachTask) + return &response, nil + // No longer need the response - resp.Free() + //resp.Free() break case hyper.ExampleRespBody: + println("ExampleRespBody") if task.Type() == hyper.TaskError { c.Printf(c.Str("body error!\n")) - err = (*hyper.Error)(task.Value()) - Fail(err) + hyperErr = (*hyper.Error)(task.Value()) + err = Fail(hyperErr) + return nil, err } if task.Type() != hyper.TaskEmpty { c.Printf(c.Str("unexpected task type\n")) - Fail(err) + err = Fail(hyperErr) + return nil, err } // Cleaning up before exiting task.Free() exec.Free() - FreeConnData(connData) + (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).Close(nil) - if response.respBodyWriter != nil { - defer response.respBodyWriter.Close() - } + FreeConnData(conn) + + //if response.respBodyWriter != nil { + // defer response.respBodyWriter.Close() + //} - return &response + return &response, nil case hyper.ExampleNotSet: + println("ExampleNotSet") // A background task for hyper_client completed... task.Free() break } } - // All futures are pending on IO work, so select on the fds. - - sys.FD_ZERO(&fdsRead) - sys.FD_ZERO(&fdsWrite) - sys.FD_ZERO(&fdsExcep) - - if connData.ReadWaker != nil { - sys.FD_SET(connData.Fd, &fdsRead) - } - if connData.WriteWaker != nil { - sys.FD_SET(connData.Fd, &fdsWrite) - } - - // Set the default request timeout - var tv syscall.Timeval - tv.Sec = 10 - - selRet := sys.Select(connData.Fd+1, &fdsRead, &fdsWrite, &fdsExcep, &tv) - if selRet < 0 { - panic("select() error\n") - } else if selRet == 0 { - panic("select() timeout\n") - } - - if sys.FD_ISSET(connData.Fd, &fdsRead) != 0 { - connData.ReadWaker.Wake() - connData.ReadWaker = nil - } - - if sys.FD_ISSET(connData.Fd, &fdsWrite) != 0 { - connData.WriteWaker.Wake() - connData.WriteWaker = nil - } + libuv.Run(loop, libuv.RUN_ONCE) } } -// ConnectTo connects to a host and port -func ConnectTo(host string, port string) c.Int { - var hints net.AddrInfo - hints.Family = net.AF_UNSPEC - hints.SockType = net.SOCK_STREAM - - var result, rp *net.AddrInfo - - if net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &result) != 0 { - panic(fmt.Sprintf("dns failed for %s\n", host)) +// AllocBuffer allocates a buffer for reading from a socket +func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { + //conn := (*ConnData)(handle.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(handle)).data + conn := (*ConnData)(handle.GetData()) + if conn.ReadBuf.Base == nil { + conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) + conn.ReadBufFilled = 0 } + *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+conn.ReadBufFilled)), c.Uint(suggestedSize-conn.ReadBufFilled)) +} - var sfd c.Int - for rp = result; rp != nil; rp = rp.Next { - sfd = net.Socket(rp.Family, rp.SockType, rp.Protocol) - if sfd == -1 { - continue - } - if net.Connect(sfd, rp.Addr, rp.AddrLen) != -1 { - break - } - os.Close(sfd) - } - - net.Freeaddrinfo(result) - - // no address succeeded - if rp == nil || sfd < 0 { - panic(fmt.Sprintf("connect failed for %s\n", host)) +// OnRead is the libuv callback for reading from a socket +func OnRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { + //conn := (*ConnData)(stream.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(stream)).data + conn := (*ConnData)((*libuv.Handle)(c.Pointer(stream)).GetData()) + if nread > 0 { + conn.ReadBufFilled += uintptr(nread) } - - if os.Fcntl(sfd, os.F_SETFL, os.O_NONBLOCK) != 0 { - panic("failed to set net to non-blocking\n") + if conn.ReadWaker != nil { + conn.ReadWaker.Wake() + conn.ReadWaker = nil } - return sfd } -// ReadCallBack is the callback for reading from a socket +// ReadCallBack is the hyper callback for reading from a socket func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { conn := (*ConnData)(userdata) - ret := os.Read(conn.Fd, c.Pointer(buf), bufLen) - - if ret >= 0 { - return uintptr(ret) - } - - if os.Errno != os.EAGAIN { - c.Perror(c.Str("[read callback fail]")) - // kaboom - return hyper.IoError + if conn.ReadBufFilled > 0 { + var toCopy uintptr + if bufLen < conn.ReadBufFilled { + toCopy = bufLen + } else { + toCopy = conn.ReadBufFilled + } + c.Memcpy(c.Pointer(buf), c.Pointer(conn.ReadBuf.Base), toCopy) + c.Memmove(c.Pointer(conn.ReadBuf.Base), c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+toCopy), conn.ReadBufFilled-toCopy) + conn.ReadBufFilled -= toCopy + return toCopy } - // would block, register interest if conn.ReadWaker != nil { conn.ReadWaker.Free() } @@ -273,22 +349,32 @@ func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uin return hyper.IoPending } -// WriteCallBack is the callback for writing to a socket -func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - conn := (*ConnData)(userdata) - ret := os.Write(conn.Fd, c.Pointer(buf), bufLen) +// OnWrite is the libuv callback for writing to a socket +func OnWrite(req *libuv.Write, status c.Int) { + //conn := (*ConnData)(req.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(req)).data + conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) - if int(ret) >= 0 { - return uintptr(ret) + if conn.WriteWaker != nil { + conn.WriteWaker.Wake() + conn.WriteWaker = nil } + c.Free(c.Pointer(req)) +} + +// WriteCallBack is the hyper callback for writing to a socket +func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + conn := (*ConnData)(userdata) + initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) + req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) + //req.Data = c.Pointer(conn) + (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) + ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, OnWrite) - if os.Errno != os.EAGAIN { - c.Perror(c.Str("[write callback fail]")) - // kaboom - return hyper.IoError + if ret >= 0 { + return bufLen } - // would block, register interest if conn.WriteWaker != nil { conn.WriteWaker.Free() } @@ -296,6 +382,19 @@ func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui return hyper.IoPending } +// OnConnect is the libuv callback for a successful connection +func OnConnect(req *libuv.Connect, status c.Int) { + //conn := (*ConnData)(req.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(req)).data + conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + + if status < 0 { + c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) + return + } + (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(AllocBuffer, OnRead) +} + // FreeConnData frees the connection data func FreeConnData(conn *ConnData) { if conn.ReadWaker != nil { @@ -306,10 +405,15 @@ func FreeConnData(conn *ConnData) { conn.WriteWaker.Free() conn.WriteWaker = nil } + if conn.ReadBuf.Base != nil { + c.Free(c.Pointer(conn.ReadBuf.Base)) + conn.ReadBuf.Base = nil + } + c.Free(c.Pointer(conn)) } // Fail prints the error details and panics -func Fail(err *hyper.Error) { +func Fail(err *hyper.Error) error { if err != nil { c.Printf(c.Str("error code: %d\n"), err.Code()) // grab the error details @@ -317,22 +421,12 @@ func Fail(err *hyper.Error) { errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) c.Printf(c.Str("details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) - c.Printf(c.Str("details: ")) - for i := 0; i < int(errLen); i++ { - c.Printf(c.Str("%c"), errBuf[i]) - } - c.Printf(c.Str("\n")) // clean up the error err.Free() - panic("request failed\n") + return fmt.Errorf("hyper error\n") } - return -} - -// NewConnData creates a new connection data -func NewConnData(fd c.Int) *ConnData { - return &ConnData{Fd: fd, ReadWaker: nil, WriteWaker: nil} + return nil } // NewIoWithConnReadWrite creates a new IO with read and write callbacks @@ -344,6 +438,12 @@ func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { return io } +// SetUserData Set the user data for the task +func SetUserData(task *hyper.Task, userData hyper.ExampleId) { + var data = userData + task.SetUserdata(c.Pointer(uintptr(data))) +} + // parseURL Parse the URL and extract the host name, port number, and URI func parseURL(rawURL string) (hostname, port, uri string) { // 找到 "://" 的位置,以分隔协议和主机名 @@ -383,6 +483,5 @@ func parseURL(rawURL string) (hostname, port, uri string) { //} port = "80" } - return } diff --git a/x/http/header.go b/x/http/header.go index ea313a2..4710854 100644 --- a/x/http/header.go +++ b/x/http/header.go @@ -25,10 +25,8 @@ func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, va func (resp *Response) PrintHeaders() { for key, values := range resp.Header { - fmt.Printf("%s: ", key) for _, value := range values { - fmt.Printf(value + "; ") + fmt.Printf("%s: %s\n", key, value) } - fmt.Printf("\n") } } diff --git a/x/http/hyper-go.go b/x/http/hyper-go.go deleted file mode 100644 index a1db081..0000000 --- a/x/http/hyper-go.go +++ /dev/null @@ -1,12 +0,0 @@ -package http - -import ( - "github.com/goplus/llgo/c" - "github.com/goplus/llgoexamples/rust/hyper" -) - -// SetUserData Set the user data for the task -func SetUserData(task *hyper.Task, userData hyper.ExampleId) { - var data = userData - task.SetUserdata(c.Pointer(uintptr(data))) -} diff --git a/x/http/response.go b/x/http/response.go index 9263461..020f2d9 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -3,7 +3,6 @@ package http import ( "fmt" "io" - "unsafe" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" @@ -13,38 +12,51 @@ type Response struct { Status string StatusCode int Header Header - ResponseBody io.ReadCloser + Content io.ReadCloser + ContentLen int64 respBodyWriter *io.PipeWriter - ResponseBodyLen int64 + ResponseBody *uint8 + ResponseBodyLen uintptr } // AppendToResponseBody (BodyForEachCallback) appends the body to the response -func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { - resp := (*Response)(userdata) - len := chunk.Len() - buf := unsafe.Slice((*byte)(chunk.Bytes()), len) +//func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { +// resp := (*Response)(userdata) +// len := chunk.Len() +// buf := unsafe.Slice((*byte)(chunk.Bytes()), len) +// +// if resp.Content == nil { +// var reader *io.PipeReader +// reader, resp.respBodyWriter = io.Pipe() +// resp.Content = io.ReadCloser(reader) +// } +// resp.ContentLen += int64(len) +// var err error +// go func() { +// _, err = resp.respBodyWriter.Write(buf) +// }() +// if err != nil { +// fmt.Printf("Failed to write response body: %v\n", err) +// return hyper.IterBreak +// } +// return hyper.IterContinue +//} - if resp.ResponseBody == nil { - var reader *io.PipeReader - reader, resp.respBodyWriter = io.Pipe() - resp.ResponseBody = io.ReadCloser(reader) - } - resp.ResponseBodyLen += int64(len) - var err error +func (resp *Response) PrintBody1() { go func() { - _, err = resp.respBodyWriter.Write(buf) + var reader *io.PipeReader + reader, writer := io.Pipe() + resp.Content = reader + writer.Write((*[1 << 30]byte)(c.Pointer(resp.ResponseBody))[:resp.ResponseBodyLen:resp.ResponseBodyLen]) + defer writer.Close() }() - if err != nil { - fmt.Printf("Failed to write response body: %v\n", err) - return hyper.IterBreak + for i := 0; i < 10; i++ { + c.Usleep(1 * 1000 * 1000) + fmt.Println("Sleeping...") } - return hyper.IterContinue -} - -func (resp *Response) PrintBody() { - var buffer = make([]byte, resp.ResponseBodyLen) + var buffer = make([]byte, 4096) for { - n, err := resp.ResponseBody.Read(buffer) + n, err := resp.Content.Read(buffer) if err == io.EOF { fmt.Printf("\n") break @@ -55,33 +67,36 @@ func (resp *Response) PrintBody() { } fmt.Printf("%s", string(buffer[:n])) } + buffer = nil + //body, _ := io.ReadAll(resp.Content) + //fmt.Println(string(body)) } -//// AppendToResponseBody (BodyForEachCallback) appends the body to the response -//func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { -// resp := (*Response)(userdata) -// buf := chunk.Bytes() -// len := chunk.Len() -// responseBody := (*uint8)(c.Malloc(resp.ResponseBodyLen + len)) -// if responseBody == nil { -// c.Fprintf(c.Stderr, c.Str("Failed to allocate memory for response body\n")) -// return hyper.IterBreak -// } -// -// // Copy the existing response body to the new buffer -// if resp.ResponseBody != nil { -// c.Memcpy(c.Pointer(responseBody), c.Pointer(resp.ResponseBody), resp.ResponseBodyLen) -// c.Free(c.Pointer(resp.ResponseBody)) -// } -// -// // Append the new data -// c.Memcpy(c.Pointer(uintptr(c.Pointer(responseBody))+resp.ResponseBodyLen), c.Pointer(buf), len) -// resp.ResponseBody = responseBody -// resp.ResponseBodyLen += len -// return hyper.IterContinue -//} +// AppendToResponseBody (BodyForEachCallback) appends the body to the response +func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { + resp := (*Response)(userdata) + buf := chunk.Bytes() + len := chunk.Len() + responseBody := (*uint8)(c.Malloc(resp.ResponseBodyLen + len)) + if responseBody == nil { + c.Fprintf(c.Stderr, c.Str("Failed to allocate memory for response body\n")) + return hyper.IterBreak + } -//func (resp *Response) PrintBody() { -// //c.Printf(c.Str("%.*s\n"), c.Int(resp.ResponseBodyLen), resp.ResponseBody) -// fmt.Println(string((*[1 << 30]byte)(c.Pointer(resp.ResponseBody))[:resp.ResponseBodyLen:resp.ResponseBodyLen])) -//} + // Copy the existing response body to the new buffer + if resp.ResponseBody != nil { + c.Memcpy(c.Pointer(responseBody), c.Pointer(resp.ResponseBody), resp.ResponseBodyLen) + c.Free(c.Pointer(resp.ResponseBody)) + } + + // Append the new data + c.Memcpy(c.Pointer(uintptr(c.Pointer(responseBody))+resp.ResponseBodyLen), c.Pointer(buf), len) + resp.ResponseBody = responseBody + resp.ResponseBodyLen += len + return hyper.IterContinue +} + +func (resp *Response) PrintBody2() { + //c.Printf(c.Str("%.*s\n"), c.Int(resp.ResponseBodyLen), resp.ResponseBody) + fmt.Println(string((*[1 << 30]byte)(c.Pointer(resp.ResponseBody))[:resp.ResponseBodyLen:resp.ResponseBodyLen])) +} From 685154ff53b5f2eb74fe0857325d80bc0e8b0797 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Fri, 2 Aug 2024 18:08:38 +0800 Subject: [PATCH 03/21] WIP(x/http-get): Use channels to pass responses --- x/httpget/_demo/get/get.go | 24 ++ x/httpget/client.go | 46 ++++ x/httpget/header.go | 32 +++ x/httpget/request.go | 42 ++++ x/httpget/response.go | 54 ++++ x/httpget/transport.go | 502 +++++++++++++++++++++++++++++++++++++ 6 files changed, 700 insertions(+) create mode 100644 x/httpget/_demo/get/get.go create mode 100644 x/httpget/client.go create mode 100644 x/httpget/header.go create mode 100644 x/httpget/request.go create mode 100644 x/httpget/response.go create mode 100644 x/httpget/transport.go diff --git a/x/httpget/_demo/get/get.go b/x/httpget/_demo/get/get.go new file mode 100644 index 0000000..da674ba --- /dev/null +++ b/x/httpget/_demo/get/get.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgo/x/httpget" +) + +func main() { + resp, err := httpget.Get("www.baidu.com") + //req, _ := httpget.NewRequest("GET", "http://www.baidu.com", nil) + //resp, err := httpget.DefaultClient.Send(req, nil) + if err != nil { + fmt.Println(err) + return + } + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) +} diff --git a/x/httpget/client.go b/x/httpget/client.go new file mode 100644 index 0000000..8a1f610 --- /dev/null +++ b/x/httpget/client.go @@ -0,0 +1,46 @@ +package httpget + +type Client struct { + Transport RoundTripper +} + +var DefaultClient = &Client{} + +type RoundTripper interface { + RoundTrip(*Request) (*Response, error) +} + +func (c *Client) transport() RoundTripper { + if c.Transport != nil { + return c.Transport + } + return DefaultTransport +} + +func Get(url string) (*Response, error) { + return DefaultClient.Get(url) +} + +func (c *Client) Get(url string) (*Response, error) { + req, err := NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + +func (c *Client) Do(req *Request) (*Response, error) { + return c.do(req) +} + +func (c *Client) do(req *Request) (*Response, error) { + return c.send(req, nil) +} + +func (c *Client) send(req *Request, deadline any) (*Response, error) { + return send(req, c.transport(), deadline) +} + +func send(req *Request, rt RoundTripper, deadline any) (resp *Response, err error) { + return rt.RoundTrip(req) +} diff --git a/x/httpget/header.go b/x/httpget/header.go new file mode 100644 index 0000000..1768557 --- /dev/null +++ b/x/httpget/header.go @@ -0,0 +1,32 @@ +package httpget + +import ( + "fmt" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type Header map[string][]string + +// AppendToResponseHeader (HeadersForEachCallback) prints each header to the console +func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { + resp := (*Response)(userdata) + nameStr := string((*[1 << 30]byte)(c.Pointer(name))[:nameLen:nameLen]) + valueStr := string((*[1 << 30]byte)(c.Pointer(value))[:valueLen:valueLen]) + + if resp.Header == nil { + resp.Header = make(map[string][]string) + } + resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) + //c.Printf(c.Str("%.*s: %.*s\n"), int(nameLen), name, int(valueLen), value) + return hyper.IterContinue +} + +func (resp *Response) PrintHeaders() { + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } +} diff --git a/x/httpget/request.go b/x/httpget/request.go new file mode 100644 index 0000000..391311c --- /dev/null +++ b/x/httpget/request.go @@ -0,0 +1,42 @@ +package httpget + +import ( + "fmt" + "io" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type Request struct { + Method string + Url string +} + +func NewRequest(method, url string, body io.Reader) (*Request, error) { + return &Request{ + Method: method, + Url: url, + }, nil +} + +func NewHyperRequest(request *Request) (*hyper.Request, error) { + host, _, uri := parseURL(request.Url) + method := request.Method + // Prepare the request + req := hyper.NewRequest() + // Set the request method and uri + if req.SetMethod((*uint8)(&[]byte(method)[0]), c.Strlen(c.AllocaCStr(method))) != hyper.OK { + return nil, fmt.Errorf("error setting method %s\n", method) + } + if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { + return nil, fmt.Errorf("error setting uri %s\n", uri) + } + + // Set the request headers + reqHeaders := req.Headers() + if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { + return nil, fmt.Errorf("error setting headers\n") + } + return req, nil +} diff --git a/x/httpget/response.go b/x/httpget/response.go new file mode 100644 index 0000000..a9e4468 --- /dev/null +++ b/x/httpget/response.go @@ -0,0 +1,54 @@ +package httpget + +import ( + "fmt" + "io" + "unsafe" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type Response struct { + Status string + StatusCode int + Header Header + Body io.ReadCloser + ContentLength int64 + respBodyWriter *io.PipeWriter +} + +// AppendToResponseBody (BodyForEachCallback) appends the body to the response +func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { + fmt.Println("reading1...") + resp := (*Response)(userdata) + len := chunk.Len() + buf := unsafe.Slice((*byte)(chunk.Bytes()), len) + _, err := resp.respBodyWriter.Write(buf) + resp.ContentLength += int64(len) + if err != nil { + fmt.Printf("Failed to write response body: %v\n", err) + return hyper.IterBreak + } + fmt.Println("reading2...") + return hyper.IterContinue +} + +func (resp *Response) PrintBody() { + var buffer = make([]byte, 4096) + for { + n, err := resp.Body.Read(buffer) + if err == io.EOF { + fmt.Printf("\n") + break + } + if err != nil { + fmt.Println("Error reading from pipe:", err) + break + } + fmt.Printf("%s", string(buffer[:n])) + } + buffer = nil + //body, _ := io.ReadAll(resp.Content) + //fmt.Println(string(body)) +} diff --git a/x/httpget/transport.go b/x/httpget/transport.go new file mode 100644 index 0000000..0450bbd --- /dev/null +++ b/x/httpget/transport.go @@ -0,0 +1,502 @@ +package httpget + +import ( + "bufio" + "fmt" + io2 "io" + "strconv" + "strings" + "unsafe" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgo/c/libuv" + "github.com/goplus/llgo/c/net" + "github.com/goplus/llgo/c/syscall" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type ConnData struct { + TcpHandle libuv.Tcp + ConnectReq libuv.Connect + ReadBuf libuv.Buf + ReadBufFilled uintptr + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker +} + +type Transport struct { +} + +var DefaultTransport RoundTripper = &Transport{} + +// persistConn wraps a connection, usually a persistent one +// (but may be used for non-keep-alive requests as well) +type persistConn struct { + // alt optionally specifies the TLS NextProto RoundTripper. + // This is used for HTTP/2 today and future protocols later. + // If it's non-nil, the rest of the fields are unused. + alt RoundTripper + + conn *ConnData + t *Transport + br *bufio.Reader // from conn + bw *bufio.Writer // to conn + nwrite int64 // bytes written + reqch chan requestAndChan // written by roundTrip; read by readLoop + writech chan writeRequest // written by roundTrip; read by writeLoop + closech chan struct{} // closed when conn closed +} + +// incomparable is a zero-width, non-comparable type. Adding it to a struct +// makes that struct also non-comparable, and generally doesn't add +// any size (as long as it's first). +type incomparable [0]func() + +type requestAndChan struct { + _ incomparable + req *hyper.Request + ch chan responseAndError // unbuffered; always send in select on callerGone +} + +// A writeRequest is sent by the caller's goroutine to the +// writeLoop's goroutine to write a request while the read loop +// concurrently waits on both the write response and the server's +// reply. +type writeRequest struct { + // req *transportRequest + ch chan<- error + + // Optional blocking chan for Expect: 100-continue (for receive). + // If not nil, writeLoop blocks sending request body until + // it receives from this chan. + continueCh <-chan struct{} +} + +// responseAndError is how the goroutine reading from an HTTP/1 server +// communicates with the goroutine doing the RoundTrip. +type responseAndError struct { + _ incomparable + res *Response // else use this response (see res method) + err error +} + +func (t *Transport) RoundTrip(request *Request) (*Response, error) { + req, err := NewHyperRequest(request) + if err != nil { + return nil, err + } + pconn, err := t.getConn(req) + var resp *Response + resp, err = pconn.roundTrip(req) + if err == nil { + return resp, nil + } + return nil, err +} + +func (t *Transport) getConn(req *hyper.Request) (pconn *persistConn, err error) { + host := "www.baidu.com" + port := "80" + loop := libuv.DefaultLoop() + conn := (*ConnData)(c.Malloc(unsafe.Sizeof(ConnData{}))) + if conn == nil { + return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") + } + + libuv.InitTcp(loop, &conn.TcpHandle) + //conn.TcpHandle.Data = c.Pointer(conn) + (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) + + var hints net.AddrInfo + c.Memset(c.Pointer(&hints), 0, unsafe.Sizeof(hints)) + hints.Family = syscall.AF_UNSPEC + hints.SockType = syscall.SOCK_STREAM + + var res *net.AddrInfo + status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) + if status != 0 { + return nil, fmt.Errorf("getaddrinfo error\n") + } + + //conn.ConnectReq.Data = c.Pointer(conn) + (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) + status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, OnConnect) + if status != 0 { + return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) + } + pconn = &persistConn{ + conn: conn, + t: t, + reqch: make(chan requestAndChan, 1), + writech: make(chan writeRequest, 1), + closech: make(chan struct{}), + } + + net.Freeaddrinfo(res) + + go pconn.startLoop(loop) + return pconn, nil +} + +func (pc *persistConn) roundTrip(req *hyper.Request) (resp *Response, err error) { + resc := make(chan responseAndError) + pc.reqch <- requestAndChan{ + req: req, + ch: resc, + } + + select { + case re := <-resc: + if (re.res == nil) == (re.err == nil) { + panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil)) + } + if re.err != nil { + return nil, err + } + return re.res, nil + } +} + +func (pc *persistConn) startLoop(loop *libuv.Loop) { + // Hookup the IO + io := NewIoWithConnReadWrite(pc.conn) + + // We need an executor generally to poll futures + exec := hyper.NewExecutor() + + // Prepare client options + opts := hyper.NewClientConnOptions() + opts.Exec(exec) + + handshakeTask := hyper.Handshake(io, opts) + SetUserData(handshakeTask, hyper.ExampleHandshake) + + // Let's wait for the handshake to finish... + exec.Push(handshakeTask) + + var hyperErr *hyper.Error + var response Response + + var rc requestAndChan + + select { + case rc = <-pc.reqch: + } + // The polling state machine! + for { + // Poll all ready tasks and act on them... + for { + task := exec.Poll() + if task == nil { + break + } + + switch (hyper.ExampleId)(uintptr(task.Userdata())) { + case hyper.ExampleHandshake: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("handshake error!\n")) + hyperErr = (*hyper.Error)(task.Value()) + Fail(hyperErr) + } + if task.Type() != hyper.TaskClientConn { + c.Printf(c.Str("unexpected task type\n")) + Fail(hyperErr) + } + + client := (*hyper.ClientConn)(task.Value()) + task.Free() + + // Send it! + sendTask := client.Send(rc.req) + SetUserData(sendTask, hyper.ExampleSend) + sendRes := exec.Push(sendTask) + if sendRes != hyper.OK { + panic("error send\n") + } + + // For this example, no longer need the client + client.Free() + + break + case hyper.ExampleSend: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("send error!\n")) + hyperErr = (*hyper.Error)(task.Value()) + Fail(hyperErr) + } + if task.Type() != hyper.TaskResponse { + c.Printf(c.Str("unexpected task type\n")) + Fail(hyperErr) + } + + // Take the results + resp := (*hyper.Response)(task.Value()) + task.Free() + + rp := resp.ReasonPhrase() + rpLen := resp.ReasonPhraseLen() + + response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) + response.StatusCode = int(resp.Status()) + + headers := resp.Headers() + headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) + respBody := resp.Body() + + response.Body, response.respBodyWriter = io2.Pipe() + + /*go func() { + fmt.Println("writing...") + for { + fmt.Println("writing for...") + dataTask := respBody.Data() + exec.Push(dataTask) + dataTask = exec.Poll() + if dataTask.Type() == hyper.TaskBuf { + buf := (*hyper.Buf)(dataTask.Value()) + len := buf.Len() + bytes := unsafe.Slice((*byte)(buf.Bytes()), len) + _, err := response.respBodyWriter.Write(bytes) + if err != nil { + fmt.Printf("Failed to write response body: %v\n", err) + break + } + dataTask.Free() + } else if dataTask.Type() == hyper.TaskEmpty { + fmt.Println("writing empty") + dataTask.Free() + break + } + } + fmt.Println("end writing") + defer response.respBodyWriter.Close() + }()*/ + + foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) + + SetUserData(foreachTask, hyper.ExampleRespBody) + exec.Push(foreachTask) + + rc.ch <- responseAndError{res: &response} + // No longer need the response + //resp.Free() + + break + case hyper.ExampleRespBody: + println("ExampleRespBody") + if task.Type() == hyper.TaskError { + c.Printf(c.Str("body error!\n")) + hyperErr = (*hyper.Error)(task.Value()) + Fail(hyperErr) + } + if task.Type() != hyper.TaskEmpty { + c.Printf(c.Str("unexpected task type\n")) + Fail(hyperErr) + } + + // Cleaning up before exiting + task.Free() + //exec.Free() + (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) + + FreeConnData(pc.conn) + + //return &response, nil + break + case hyper.ExampleNotSet: + println("ExampleNotSet") + // A background task for hyper_client completed... + task.Free() + break + } + } + + libuv.Run(loop, libuv.RUN_ONCE) + } +} + +// AllocBuffer allocates a buffer for reading from a socket +func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { + //conn := (*ConnData)(handle.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(handle)).data + conn := (*ConnData)(handle.GetData()) + if conn.ReadBuf.Base == nil { + conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) + conn.ReadBufFilled = 0 + } + *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+conn.ReadBufFilled)), c.Uint(suggestedSize-conn.ReadBufFilled)) +} + +// OnRead is the libuv callback for reading from a socket +func OnRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { + //conn := (*ConnData)(stream.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(stream)).data + conn := (*ConnData)((*libuv.Handle)(c.Pointer(stream)).GetData()) + if nread > 0 { + conn.ReadBufFilled += uintptr(nread) + } + if conn.ReadWaker != nil { + conn.ReadWaker.Wake() + conn.ReadWaker = nil + } +} + +// ReadCallBack is the hyper callback for reading from a socket +func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + conn := (*ConnData)(userdata) + + if conn.ReadBufFilled > 0 { + var toCopy uintptr + if bufLen < conn.ReadBufFilled { + toCopy = bufLen + } else { + toCopy = conn.ReadBufFilled + } + c.Memcpy(c.Pointer(buf), c.Pointer(conn.ReadBuf.Base), toCopy) + c.Memmove(c.Pointer(conn.ReadBuf.Base), c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+toCopy), conn.ReadBufFilled-toCopy) + conn.ReadBufFilled -= toCopy + return toCopy + } + + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + } + conn.ReadWaker = ctx.Waker() + return hyper.IoPending +} + +// OnWrite is the libuv callback for writing to a socket +func OnWrite(req *libuv.Write, status c.Int) { + //conn := (*ConnData)(req.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(req)).data + conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + + if conn.WriteWaker != nil { + conn.WriteWaker.Wake() + conn.WriteWaker = nil + } + c.Free(c.Pointer(req)) +} + +// WriteCallBack is the hyper callback for writing to a socket +func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + conn := (*ConnData)(userdata) + initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) + req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) + //req.Data = c.Pointer(conn) + (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) + ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, OnWrite) + + if ret >= 0 { + return bufLen + } + + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + } + conn.WriteWaker = ctx.Waker() + return hyper.IoPending +} + +// OnConnect is the libuv callback for a successful connection +func OnConnect(req *libuv.Connect, status c.Int) { + //conn := (*ConnData)(req.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(req)).data + conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + + if status < 0 { + c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) + return + } + (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(AllocBuffer, OnRead) +} + +// FreeConnData frees the connection data +func FreeConnData(conn *ConnData) { + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + conn.ReadWaker = nil + } + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + conn.WriteWaker = nil + } + if conn.ReadBuf.Base != nil { + c.Free(c.Pointer(conn.ReadBuf.Base)) + conn.ReadBuf.Base = nil + } + c.Free(c.Pointer(conn)) +} + +// Fail prints the error details and panics +func Fail(err *hyper.Error) { + if err != nil { + c.Printf(c.Str("error code: %d\n"), err.Code()) + // grab the error details + var errBuf [256]c.Char + errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) + + c.Printf(c.Str("details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) + + // clean up the error + err.Free() + panic("hyper error \n") + } +} + +// NewIoWithConnReadWrite creates a new IO with read and write callbacks +func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { + io := hyper.NewIo() + io.SetUserdata(c.Pointer(connData)) + io.SetRead(ReadCallBack) + io.SetWrite(WriteCallBack) + return io +} + +// SetUserData Set the user data for the task +func SetUserData(task *hyper.Task, userData hyper.ExampleId) { + var data = userData + task.SetUserdata(c.Pointer(uintptr(data))) +} + +// parseURL Parse the URL and extract the host name, port number, and URI +func parseURL(rawURL string) (hostname, port, uri string) { + // 找到 "://" 的位置,以分隔协议和主机名 + schemeEnd := strings.Index(rawURL, "://") + if schemeEnd != -1 { + //scheme = rawURL[:schemeEnd] + rawURL = rawURL[schemeEnd+3:] + } else { + //scheme = "http" // 默认协议为 http + } + + // 找到第一个 "/" 的位置,以分隔主机名和路径 + pathStart := strings.Index(rawURL, "/") + if pathStart != -1 { + uri = rawURL[pathStart:] + rawURL = rawURL[:pathStart] + } else { + uri = "/" + } + + // 找到 ":" 的位置,以分隔主机名和端口号 + portStart := strings.LastIndex(rawURL, ":") + if portStart != -1 { + hostname = rawURL[:portStart] + port = rawURL[portStart+1:] + } else { + hostname = rawURL + port = "" // 未指定端口号 + } + + // 如果未指定端口号,根据协议设置默认端口号 + if port == "" { + //if scheme == "https" { + // port = "443" + //} else { + // port = "80" + //} + port = "80" + } + return +} From c52bdd5f06fd7f22a7c0cb3c4709de030c457d90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Mon, 5 Aug 2024 17:03:35 +0800 Subject: [PATCH 04/21] WIP(x/http/client/get): Use channels to pass response(Passed the test) --- go.mod | 2 +- go.sum | 4 +- x/http/_demo/get/get.go | 19 +- x/http/client.go | 453 +------------------------------ x/{httpget => http}/request.go | 25 +- x/http/response.go | 93 +------ x/{httpget => http}/transport.go | 360 +++++++++++------------- x/httpget/_demo/get/get.go | 24 -- x/httpget/client.go | 46 ---- x/httpget/header.go | 32 --- x/httpget/response.go | 54 ---- 11 files changed, 201 insertions(+), 911 deletions(-) rename x/{httpget => http}/request.go (65%) rename x/{httpget => http}/transport.go (56%) delete mode 100644 x/httpget/_demo/get/get.go delete mode 100644 x/httpget/client.go delete mode 100644 x/httpget/header.go delete mode 100644 x/httpget/response.go diff --git a/go.mod b/go.mod index 39080a4..fa05f1f 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/goplus/llgoexamples go 1.20 -require github.com/goplus/llgo v0.9.5-0.20240731053840-36072584d0be +require github.com/goplus/llgo v0.9.5-0.20240805045323-4bff9cc3df0c diff --git a/go.sum b/go.sum index e2c5d17..ba1d000 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -github.com/goplus/llgo v0.9.5-0.20240731053840-36072584d0be h1:FTALxA3ivIeVRAO93e1hCSCLaPbjKn+RZx40p5lx8KE= -github.com/goplus/llgo v0.9.5-0.20240731053840-36072584d0be/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= +github.com/goplus/llgo v0.9.5-0.20240805045323-4bff9cc3df0c h1:PhaSnZL8LLyRIHWc5Wim9No0Q475H8Ljikxfj1gHHjc= +github.com/goplus/llgo v0.9.5-0.20240805045323-4bff9cc3df0c/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= diff --git a/x/http/_demo/get/get.go b/x/http/_demo/get/get.go index 09f32a0..73e7113 100644 --- a/x/http/_demo/get/get.go +++ b/x/http/_demo/get/get.go @@ -2,23 +2,24 @@ package main import ( "fmt" + "io" "github.com/goplus/llgoexamples/x/http" ) func main() { - // 使用 http.Get 发送 GET 请求 - resp, err := http.Get("https://www.baidu.com/") + resp, err := http.Get("https://www.baidu.com") if err != nil { fmt.Println(err) return } - fmt.Println(resp.Status) - fmt.Println(resp.StatusCode) + println(resp.Status) resp.PrintHeaders() - fmt.Println() - resp.PrintBody2() - - resp.PrintBody1() - defer resp.Content.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + defer resp.Body.Close() } diff --git a/x/http/client.go b/x/http/client.go index d173574..ac0bc6e 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -1,28 +1,5 @@ package http -import ( - "fmt" - io2 "io" - "strconv" - "strings" - "unsafe" - - "github.com/goplus/llgo/c" - "github.com/goplus/llgo/c/libuv" - "github.com/goplus/llgo/c/net" - "github.com/goplus/llgo/c/syscall" - "github.com/goplus/llgoexamples/rust/hyper" -) - -type ConnData struct { - TcpHandle libuv.Tcp - ConnectReq libuv.Connect - ReadBuf libuv.Buf - ReadBufFilled uintptr - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker -} - type Client struct { Transport RoundTripper } @@ -30,7 +7,7 @@ type Client struct { var DefaultClient = &Client{} type RoundTripper interface { - RoundTrip(*hyper.Request) (*Response, error) + RoundTrip(*Request) (*Response, error) } func (c *Client) transport() RoundTripper { @@ -40,7 +17,7 @@ func (c *Client) transport() RoundTripper { return DefaultTransport } -func Get2(url string) (*Response, error) { +func Get(url string) (*Response, error) { return DefaultClient.Get(url) } @@ -52,436 +29,18 @@ func (c *Client) Get(url string) (*Response, error) { return c.Do(req) } -func (c *Client) Do(req *hyper.Request) (*Response, error) { +func (c *Client) Do(req *Request) (*Response, error) { return c.do(req) } -func (c *Client) do(req *hyper.Request) (*Response, error) { +func (c *Client) do(req *Request) (*Response, error) { return c.send(req, nil) } -func (c *Client) send(req *hyper.Request, deadline any) (*Response, error) { +func (c *Client) send(req *Request, deadline any) (*Response, error) { return send(req, c.transport(), deadline) } -func send(req *hyper.Request, rt RoundTripper, deadline any) (resp *Response, err error) { +func send(req *Request, rt RoundTripper, deadline any) (resp *Response, err error) { return rt.RoundTrip(req) } - -func NewRequest(method, url string, body io2.Reader) (*hyper.Request, error) { - host, _, uri := parseURL(url) - // Prepare the request - req := hyper.NewRequest() - // Set the request method and uri - if req.SetMethod((*uint8)(&[]byte(method)[0]), c.Strlen(c.AllocaCStr(method))) != hyper.OK { - return nil, fmt.Errorf("error setting method %s\n", method) - } - if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { - return nil, fmt.Errorf("error setting uri %s\n", uri) - } - - // Set the request headers - reqHeaders := req.Headers() - if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { - return nil, fmt.Errorf("error setting headers\n") - } - return req, nil -} - -func Get(url string) (_ *Response, err error) { - host, port, uri := parseURL(url) - - loop := libuv.DefaultLoop() - conn := (*ConnData)(c.Malloc(unsafe.Sizeof(ConnData{}))) - if conn == nil { - return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") - } - - libuv.InitTcp(loop, &conn.TcpHandle) - //conn.TcpHandle.Data = c.Pointer(conn) - (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) - - var hints net.AddrInfo - c.Memset(c.Pointer(&hints), 0, unsafe.Sizeof(hints)) - hints.Family = syscall.AF_UNSPEC - hints.SockType = syscall.SOCK_STREAM - - var res *net.AddrInfo - status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) - if status != 0 { - return nil, fmt.Errorf("getaddrinfo error\n") - } - - //conn.ConnectReq.Data = c.Pointer(conn) - (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) - status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, OnConnect) - if status != 0 { - return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) - } - - net.Freeaddrinfo(res) - - // Hookup the IO - io := NewIoWithConnReadWrite(conn) - - // We need an executor generally to poll futures - exec := hyper.NewExecutor() - - // Prepare client options - opts := hyper.NewClientConnOptions() - opts.Exec(exec) - - handshakeTask := hyper.Handshake(io, opts) - SetUserData(handshakeTask, hyper.ExampleHandshake) - - // Let's wait for the handshake to finish... - exec.Push(handshakeTask) - - var hyperErr *hyper.Error - var response Response - - // The polling state machine! - for { - // Poll all ready tasks and act on them... - for { - task := exec.Poll() - if task == nil { - break - } - - switch (hyper.ExampleId)(uintptr(task.Userdata())) { - case hyper.ExampleHandshake: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("handshake error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - err = Fail(hyperErr) - return nil, err - } - if task.Type() != hyper.TaskClientConn { - c.Printf(c.Str("unexpected task type\n")) - err = Fail(hyperErr) - return nil, err - } - - client := (*hyper.ClientConn)(task.Value()) - task.Free() - - // Prepare the request - req := hyper.NewRequest() - // Set the request method and uri - if req.SetMethod((*uint8)(&[]byte("GET")[0]), c.Strlen(c.Str("GET"))) != hyper.OK { - return nil, fmt.Errorf("error setting method %s\n", "GET") - } - if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { - return nil, fmt.Errorf("error setting uri %s\n", uri) - } - - // Set the request headers - reqHeaders := req.Headers() - if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { - return nil, fmt.Errorf("error setting headers\n") - } - - // Send it! - sendTask := client.Send(req) - SetUserData(sendTask, hyper.ExampleSend) - sendRes := exec.Push(sendTask) - if sendRes != hyper.OK { - return nil, fmt.Errorf("error send\n") - } - - // For this example, no longer need the client - client.Free() - - break - case hyper.ExampleSend: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("send error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - err = Fail(hyperErr) - return nil, err - } - if task.Type() != hyper.TaskResponse { - c.Printf(c.Str("unexpected task type\n")) - err = Fail(hyperErr) - return nil, err - } - - // Take the results - resp := (*hyper.Response)(task.Value()) - task.Free() - - rp := resp.ReasonPhrase() - rpLen := resp.ReasonPhraseLen() - - response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) - response.StatusCode = int(resp.Status()) - - headers := resp.Headers() - headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) - respBody := resp.Body() - - response.Body, response.respBodyWriter = io2.Pipe() - - /*go func() { - fmt.Println("writing...") - for { - fmt.Println("writing for...") - dataTask := respBody.Data() - exec.Push(dataTask) - dataTask = exec.Poll() - if dataTask.Type() == hyper.TaskBuf { - buf := (*hyper.Buf)(dataTask.Value()) - len := buf.Len() - bytes := unsafe.Slice((*byte)(buf.Bytes()), len) - _, err := response.respBodyWriter.Write(bytes) - if err != nil { - fmt.Printf("Failed to write response body: %v\n", err) - break - } - dataTask.Free() - } else if dataTask.Type() == hyper.TaskEmpty { - fmt.Println("writing empty") - dataTask.Free() - break - } - } - fmt.Println("end writing") - defer response.respBodyWriter.Close() - }()*/ - - foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) - - SetUserData(foreachTask, hyper.ExampleRespBody) - exec.Push(foreachTask) - - return &response, nil - - // No longer need the response - //resp.Free() - - break - case hyper.ExampleRespBody: - println("ExampleRespBody") - if task.Type() == hyper.TaskError { - c.Printf(c.Str("body error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - err = Fail(hyperErr) - return nil, err - } - if task.Type() != hyper.TaskEmpty { - c.Printf(c.Str("unexpected task type\n")) - err = Fail(hyperErr) - return nil, err - } - - // Cleaning up before exiting - task.Free() - exec.Free() - (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).Close(nil) - - FreeConnData(conn) - - //if response.respBodyWriter != nil { - // defer response.respBodyWriter.Close() - //} - - return &response, nil - case hyper.ExampleNotSet: - println("ExampleNotSet") - // A background task for hyper_client completed... - task.Free() - break - } - } - - libuv.Run(loop, libuv.RUN_ONCE) - } -} - -// AllocBuffer allocates a buffer for reading from a socket -func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { - //conn := (*ConnData)(handle.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(handle)).data - conn := (*ConnData)(handle.GetData()) - if conn.ReadBuf.Base == nil { - conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) - conn.ReadBufFilled = 0 - } - *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+conn.ReadBufFilled)), c.Uint(suggestedSize-conn.ReadBufFilled)) -} - -// OnRead is the libuv callback for reading from a socket -func OnRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { - //conn := (*ConnData)(stream.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(stream)).data - conn := (*ConnData)((*libuv.Handle)(c.Pointer(stream)).GetData()) - if nread > 0 { - conn.ReadBufFilled += uintptr(nread) - } - if conn.ReadWaker != nil { - conn.ReadWaker.Wake() - conn.ReadWaker = nil - } -} - -// ReadCallBack is the hyper callback for reading from a socket -func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - conn := (*ConnData)(userdata) - - if conn.ReadBufFilled > 0 { - var toCopy uintptr - if bufLen < conn.ReadBufFilled { - toCopy = bufLen - } else { - toCopy = conn.ReadBufFilled - } - c.Memcpy(c.Pointer(buf), c.Pointer(conn.ReadBuf.Base), toCopy) - c.Memmove(c.Pointer(conn.ReadBuf.Base), c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+toCopy), conn.ReadBufFilled-toCopy) - conn.ReadBufFilled -= toCopy - return toCopy - } - - if conn.ReadWaker != nil { - conn.ReadWaker.Free() - } - conn.ReadWaker = ctx.Waker() - return hyper.IoPending -} - -// OnWrite is the libuv callback for writing to a socket -func OnWrite(req *libuv.Write, status c.Int) { - //conn := (*ConnData)(req.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(req)).data - conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) - - if conn.WriteWaker != nil { - conn.WriteWaker.Wake() - conn.WriteWaker = nil - } - c.Free(c.Pointer(req)) -} - -// WriteCallBack is the hyper callback for writing to a socket -func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - conn := (*ConnData)(userdata) - initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) - req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) - //req.Data = c.Pointer(conn) - (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) - ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, OnWrite) - - if ret >= 0 { - return bufLen - } - - if conn.WriteWaker != nil { - conn.WriteWaker.Free() - } - conn.WriteWaker = ctx.Waker() - return hyper.IoPending -} - -// OnConnect is the libuv callback for a successful connection -func OnConnect(req *libuv.Connect, status c.Int) { - //conn := (*ConnData)(req.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(req)).data - conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) - - if status < 0 { - c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) - return - } - (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(AllocBuffer, OnRead) -} - -// FreeConnData frees the connection data -func FreeConnData(conn *ConnData) { - if conn.ReadWaker != nil { - conn.ReadWaker.Free() - conn.ReadWaker = nil - } - if conn.WriteWaker != nil { - conn.WriteWaker.Free() - conn.WriteWaker = nil - } - if conn.ReadBuf.Base != nil { - c.Free(c.Pointer(conn.ReadBuf.Base)) - conn.ReadBuf.Base = nil - } - c.Free(c.Pointer(conn)) -} - -// Fail prints the error details and panics -func Fail(err *hyper.Error) error { - if err != nil { - c.Printf(c.Str("error code: %d\n"), err.Code()) - // grab the error details - var errBuf [256]c.Char - errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) - - c.Printf(c.Str("details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) - - // clean up the error - err.Free() - return fmt.Errorf("hyper error\n") - } - return nil -} - -// NewIoWithConnReadWrite creates a new IO with read and write callbacks -func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { - io := hyper.NewIo() - io.SetUserdata(c.Pointer(connData)) - io.SetRead(ReadCallBack) - io.SetWrite(WriteCallBack) - return io -} - -// SetUserData Set the user data for the task -func SetUserData(task *hyper.Task, userData hyper.ExampleId) { - var data = userData - task.SetUserdata(c.Pointer(uintptr(data))) -} - -// parseURL Parse the URL and extract the host name, port number, and URI -func parseURL(rawURL string) (hostname, port, uri string) { - // 找到 "://" 的位置,以分隔协议和主机名 - schemeEnd := strings.Index(rawURL, "://") - if schemeEnd != -1 { - //scheme = rawURL[:schemeEnd] - rawURL = rawURL[schemeEnd+3:] - } else { - //scheme = "http" // 默认协议为 http - } - - // 找到第一个 "/" 的位置,以分隔主机名和路径 - pathStart := strings.Index(rawURL, "/") - if pathStart != -1 { - uri = rawURL[pathStart:] - rawURL = rawURL[:pathStart] - } else { - uri = "/" - } - - // 找到 ":" 的位置,以分隔主机名和端口号 - portStart := strings.LastIndex(rawURL, ":") - if portStart != -1 { - hostname = rawURL[:portStart] - port = rawURL[portStart+1:] - } else { - hostname = rawURL - port = "" // 未指定端口号 - } - - // 如果未指定端口号,根据协议设置默认端口号 - if port == "" { - //if scheme == "https" { - // port = "443" - //} else { - // port = "80" - //} - port = "80" - } - return -} diff --git a/x/httpget/request.go b/x/http/request.go similarity index 65% rename from x/httpget/request.go rename to x/http/request.go index 391311c..98d48b4 100644 --- a/x/httpget/request.go +++ b/x/http/request.go @@ -1,8 +1,9 @@ -package httpget +package http import ( "fmt" "io" + "net/url" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" @@ -10,19 +11,29 @@ import ( type Request struct { Method string - Url string + URL *url.URL + Req *hyper.Request } -func NewRequest(method, url string, body io.Reader) (*Request, error) { +func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { + parseURL, err := url.Parse(urlStr) + if err != nil { + return nil, err + } + req, err := NewHyperRequest(method, parseURL) + if err != nil { + return nil, err + } return &Request{ Method: method, - Url: url, + URL: parseURL, + Req: req, }, nil } -func NewHyperRequest(request *Request) (*hyper.Request, error) { - host, _, uri := parseURL(request.Url) - method := request.Method +func NewHyperRequest(method string, URL *url.URL) (*hyper.Request, error) { + host := URL.Hostname() + uri := URL.RequestURI() // Prepare the request req := hyper.NewRequest() // Set the request method and uri diff --git a/x/http/response.go b/x/http/response.go index 020f2d9..8d01b80 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -3,100 +3,31 @@ package http import ( "fmt" "io" + "unsafe" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" ) type Response struct { - Status string - StatusCode int - Header Header - Content io.ReadCloser - ContentLen int64 - respBodyWriter *io.PipeWriter - ResponseBody *uint8 - ResponseBodyLen uintptr -} - -// AppendToResponseBody (BodyForEachCallback) appends the body to the response -//func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { -// resp := (*Response)(userdata) -// len := chunk.Len() -// buf := unsafe.Slice((*byte)(chunk.Bytes()), len) -// -// if resp.Content == nil { -// var reader *io.PipeReader -// reader, resp.respBodyWriter = io.Pipe() -// resp.Content = io.ReadCloser(reader) -// } -// resp.ContentLen += int64(len) -// var err error -// go func() { -// _, err = resp.respBodyWriter.Write(buf) -// }() -// if err != nil { -// fmt.Printf("Failed to write response body: %v\n", err) -// return hyper.IterBreak -// } -// return hyper.IterContinue -//} - -func (resp *Response) PrintBody1() { - go func() { - var reader *io.PipeReader - reader, writer := io.Pipe() - resp.Content = reader - writer.Write((*[1 << 30]byte)(c.Pointer(resp.ResponseBody))[:resp.ResponseBodyLen:resp.ResponseBodyLen]) - defer writer.Close() - }() - for i := 0; i < 10; i++ { - c.Usleep(1 * 1000 * 1000) - fmt.Println("Sleeping...") - } - var buffer = make([]byte, 4096) - for { - n, err := resp.Content.Read(buffer) - if err == io.EOF { - fmt.Printf("\n") - break - } - if err != nil { - fmt.Println("Error reading from pipe:", err) - break - } - fmt.Printf("%s", string(buffer[:n])) - } - buffer = nil - //body, _ := io.ReadAll(resp.Content) - //fmt.Println(string(body)) + Status string + StatusCode int + Header Header + Body io.ReadCloser + ContentLength int64 + respBodyWriter *io.PipeWriter } // AppendToResponseBody (BodyForEachCallback) appends the body to the response func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { resp := (*Response)(userdata) - buf := chunk.Bytes() len := chunk.Len() - responseBody := (*uint8)(c.Malloc(resp.ResponseBodyLen + len)) - if responseBody == nil { - c.Fprintf(c.Stderr, c.Str("Failed to allocate memory for response body\n")) + buf := unsafe.Slice((*byte)(chunk.Bytes()), len) + _, err := resp.respBodyWriter.Write(buf) + resp.ContentLength += int64(len) + if err != nil { + fmt.Printf("Failed to write response body: %v\n", err) return hyper.IterBreak } - - // Copy the existing response body to the new buffer - if resp.ResponseBody != nil { - c.Memcpy(c.Pointer(responseBody), c.Pointer(resp.ResponseBody), resp.ResponseBodyLen) - c.Free(c.Pointer(resp.ResponseBody)) - } - - // Append the new data - c.Memcpy(c.Pointer(uintptr(c.Pointer(responseBody))+resp.ResponseBodyLen), c.Pointer(buf), len) - resp.ResponseBody = responseBody - resp.ResponseBodyLen += len return hyper.IterContinue } - -func (resp *Response) PrintBody2() { - //c.Printf(c.Str("%.*s\n"), c.Int(resp.ResponseBodyLen), resp.ResponseBody) - fmt.Println(string((*[1 << 30]byte)(c.Pointer(resp.ResponseBody))[:resp.ResponseBodyLen:resp.ResponseBodyLen])) -} diff --git a/x/httpget/transport.go b/x/http/transport.go similarity index 56% rename from x/httpget/transport.go rename to x/http/transport.go index 0450bbd..c2559a8 100644 --- a/x/httpget/transport.go +++ b/x/http/transport.go @@ -1,11 +1,9 @@ -package httpget +package http import ( - "bufio" "fmt" io2 "io" "strconv" - "strings" "unsafe" "github.com/goplus/llgo/c" @@ -27,6 +25,16 @@ type ConnData struct { type Transport struct { } +// TaskId The unique identifier of the next task polled from the executor +type TaskId c.Int + +const ( + NotSet TaskId = iota + Send + ReceiveResp + ReceiveRespBody +) + var DefaultTransport RoundTripper = &Transport{} // persistConn wraps a connection, usually a persistent one @@ -35,16 +43,15 @@ type persistConn struct { // alt optionally specifies the TLS NextProto RoundTripper. // This is used for HTTP/2 today and future protocols later. // If it's non-nil, the rest of the fields are unused. - alt RoundTripper - - conn *ConnData - t *Transport - br *bufio.Reader // from conn - bw *bufio.Writer // to conn - nwrite int64 // bytes written - reqch chan requestAndChan // written by roundTrip; read by readLoop - writech chan writeRequest // written by roundTrip; read by writeLoop - closech chan struct{} // closed when conn closed + //alt RoundTripper + //br *bufio.Reader // from conn + //bw *bufio.Writer // to conn + //nwrite int64 // bytes written + //writech chan writeRequest // written by roundTrip; read by writeLoop + //closech chan struct{} // closed when conn closed + conn *ConnData + t *Transport + reqch chan requestAndChan // written by roundTrip; read by readLoop } // incomparable is a zero-width, non-comparable type. Adding it to a struct @@ -58,20 +65,6 @@ type requestAndChan struct { ch chan responseAndError // unbuffered; always send in select on callerGone } -// A writeRequest is sent by the caller's goroutine to the -// writeLoop's goroutine to write a request while the read loop -// concurrently waits on both the write response and the server's -// reply. -type writeRequest struct { - // req *transportRequest - ch chan<- error - - // Optional blocking chan for Expect: 100-continue (for receive). - // If not nil, writeLoop blocks sending request body until - // it receives from this chan. - continueCh <-chan struct{} -} - // responseAndError is how the goroutine reading from an HTTP/1 server // communicates with the goroutine doing the RoundTrip. type responseAndError struct { @@ -80,23 +73,26 @@ type responseAndError struct { err error } -func (t *Transport) RoundTrip(request *Request) (*Response, error) { - req, err := NewHyperRequest(request) +func (t *Transport) RoundTrip(req *Request) (*Response, error) { + pconn, err := t.getConn(req) if err != nil { return nil, err } - pconn, err := t.getConn(req) var resp *Response resp, err = pconn.roundTrip(req) - if err == nil { - return resp, nil + if err != nil { + return nil, err } - return nil, err + return resp, nil } -func (t *Transport) getConn(req *hyper.Request) (pconn *persistConn, err error) { - host := "www.baidu.com" - port := "80" +func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { + host := req.URL.Hostname() + port := req.URL.Port() + if port == "" { + // Hyper only supports http + port = "80" + } loop := libuv.DefaultLoop() conn := (*ConnData)(c.Malloc(unsafe.Sizeof(ConnData{}))) if conn == nil { @@ -125,23 +121,23 @@ func (t *Transport) getConn(req *hyper.Request) (pconn *persistConn, err error) return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } pconn = &persistConn{ - conn: conn, - t: t, - reqch: make(chan requestAndChan, 1), - writech: make(chan writeRequest, 1), - closech: make(chan struct{}), + conn: conn, + t: t, + reqch: make(chan requestAndChan, 1), + //writech: make(chan writeRequest, 1), + //closech: make(chan struct{}), } net.Freeaddrinfo(res) - go pconn.startLoop(loop) + go pconn.readWriteLoop(loop) return pconn, nil } -func (pc *persistConn) roundTrip(req *hyper.Request) (resp *Response, err error) { +func (pc *persistConn) roundTrip(req *Request) (resp *Response, err error) { resc := make(chan responseAndError) pc.reqch <- requestAndChan{ - req: req, + req: req.Req, ch: resc, } @@ -157,7 +153,7 @@ func (pc *persistConn) roundTrip(req *hyper.Request) (resp *Response, err error) } } -func (pc *persistConn) startLoop(loop *libuv.Loop) { +func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Hookup the IO io := NewIoWithConnReadWrite(pc.conn) @@ -169,150 +165,140 @@ func (pc *persistConn) startLoop(loop *libuv.Loop) { opts.Exec(exec) handshakeTask := hyper.Handshake(io, opts) - SetUserData(handshakeTask, hyper.ExampleHandshake) + SetTaskId(handshakeTask, Send) // Let's wait for the handshake to finish... exec.Push(handshakeTask) + // The polling state machine! + //for { + // Poll all ready tasks and act on them... + rc := <-pc.reqch // blocking + alive := true var hyperErr *hyper.Error var response Response + var respBody *hyper.Body = nil + for alive { + task := exec.Poll() + if task == nil { + //break + libuv.Run(loop, libuv.RUN_ONCE) + continue + } - var rc requestAndChan + switch (TaskId)(uintptr(task.Userdata())) { + case Send: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("handshake error!\n")) + hyperErr = (*hyper.Error)(task.Value()) + Fail(hyperErr) + } + if task.Type() != hyper.TaskClientConn { + c.Printf(c.Str("unexpected task type\n")) + Fail(hyperErr) + } - select { - case rc = <-pc.reqch: - } - // The polling state machine! - for { - // Poll all ready tasks and act on them... - for { - task := exec.Poll() - if task == nil { - break + client := (*hyper.ClientConn)(task.Value()) + task.Free() + + // Send it! + sendTask := client.Send(rc.req) + SetTaskId(sendTask, ReceiveResp) + sendRes := exec.Push(sendTask) + if sendRes != hyper.OK { + panic("error send\n") } - switch (hyper.ExampleId)(uintptr(task.Userdata())) { - case hyper.ExampleHandshake: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("handshake error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - Fail(hyperErr) - } - if task.Type() != hyper.TaskClientConn { - c.Printf(c.Str("unexpected task type\n")) - Fail(hyperErr) - } + // For this example, no longer need the client + client.Free() - client := (*hyper.ClientConn)(task.Value()) - task.Free() + case ReceiveResp: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("send error!\n")) + hyperErr = (*hyper.Error)(task.Value()) + Fail(hyperErr) + } + if task.Type() != hyper.TaskResponse { + c.Printf(c.Str("unexpected task type\n")) + Fail(hyperErr) + } - // Send it! - sendTask := client.Send(rc.req) - SetUserData(sendTask, hyper.ExampleSend) - sendRes := exec.Push(sendTask) - if sendRes != hyper.OK { - panic("error send\n") - } + // Take the results + resp := (*hyper.Response)(task.Value()) + task.Free() - // For this example, no longer need the client - client.Free() + rp := resp.ReasonPhrase() + rpLen := resp.ReasonPhraseLen() - break - case hyper.ExampleSend: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("send error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - Fail(hyperErr) - } - if task.Type() != hyper.TaskResponse { - c.Printf(c.Str("unexpected task type\n")) - Fail(hyperErr) - } + response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) + response.StatusCode = int(resp.Status()) - // Take the results - resp := (*hyper.Response)(task.Value()) - task.Free() + headers := resp.Headers() + headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) + //respBody := resp.Body() + respBody = resp.Body() - rp := resp.ReasonPhrase() - rpLen := resp.ReasonPhraseLen() - - response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) - response.StatusCode = int(resp.Status()) - - headers := resp.Headers() - headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) - respBody := resp.Body() - - response.Body, response.respBodyWriter = io2.Pipe() - - /*go func() { - fmt.Println("writing...") - for { - fmt.Println("writing for...") - dataTask := respBody.Data() - exec.Push(dataTask) - dataTask = exec.Poll() - if dataTask.Type() == hyper.TaskBuf { - buf := (*hyper.Buf)(dataTask.Value()) - len := buf.Len() - bytes := unsafe.Slice((*byte)(buf.Bytes()), len) - _, err := response.respBodyWriter.Write(bytes) - if err != nil { - fmt.Printf("Failed to write response body: %v\n", err) - break - } - dataTask.Free() - } else if dataTask.Type() == hyper.TaskEmpty { - fmt.Println("writing empty") - dataTask.Free() - break - } - } - fmt.Println("end writing") - defer response.respBodyWriter.Close() - }()*/ - - foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) - - SetUserData(foreachTask, hyper.ExampleRespBody) - exec.Push(foreachTask) - - rc.ch <- responseAndError{res: &response} - // No longer need the response - //resp.Free() + response.Body, response.respBodyWriter = io2.Pipe() - break - case hyper.ExampleRespBody: - println("ExampleRespBody") - if task.Type() == hyper.TaskError { - c.Printf(c.Str("body error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - Fail(hyperErr) - } - if task.Type() != hyper.TaskEmpty { - c.Printf(c.Str("unexpected task type\n")) - Fail(hyperErr) - } + //foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) + //SetTaskId(foreachTask, ReceiveRespBody) + //exec.Push(foreachTask) - // Cleaning up before exiting - task.Free() - //exec.Free() - (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) + rc.ch <- responseAndError{res: &response} - FreeConnData(pc.conn) + dataTask := respBody.Data() + SetTaskId(dataTask, ReceiveRespBody) + exec.Push(dataTask) - //return &response, nil - break - case hyper.ExampleNotSet: - println("ExampleNotSet") - // A background task for hyper_client completed... + // No longer need the response + resp.Free() + + case ReceiveRespBody: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("body error!\n")) + hyperErr = (*hyper.Error)(task.Value()) + Fail(hyperErr) + } + if task.Type() == hyper.TaskBuf { + buf := (*hyper.Buf)(task.Value()) + bufLen := buf.Len() + bytes := unsafe.Slice((*byte)(buf.Bytes()), bufLen) + _, err := response.respBodyWriter.Write(bytes) // blocking + if err != nil { + panic("[readWriteLoop(): case ReceiveRespBody] error write\n") + } + buf.Free() task.Free() + + dataTask := respBody.Data() + SetTaskId(dataTask, ReceiveRespBody) + exec.Push(dataTask) + break } + // task.Type() == hyper.TaskEmpty + if task.Type() != hyper.TaskEmpty { + c.Printf(c.Str("unexpected task type\n")) + Fail(hyperErr) + } + // Cleaning up before exiting + task.Free() + respBody.Free() + response.respBodyWriter.Close() + exec.Free() + (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) + FreeConnData(pc.conn) + + close(rc.ch) + close(pc.reqch) + + alive = false + case NotSet: + // A background task for hyper_client completed... + task.Free() } - - libuv.Run(loop, libuv.RUN_ONCE) } + //} } // AllocBuffer allocates a buffer for reading from a socket @@ -453,50 +439,8 @@ func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { return io } -// SetUserData Set the user data for the task -func SetUserData(task *hyper.Task, userData hyper.ExampleId) { +// SetTaskId Set TaskId to the task's userdata as a unique identifier +func SetTaskId(task *hyper.Task, userData TaskId) { var data = userData task.SetUserdata(c.Pointer(uintptr(data))) } - -// parseURL Parse the URL and extract the host name, port number, and URI -func parseURL(rawURL string) (hostname, port, uri string) { - // 找到 "://" 的位置,以分隔协议和主机名 - schemeEnd := strings.Index(rawURL, "://") - if schemeEnd != -1 { - //scheme = rawURL[:schemeEnd] - rawURL = rawURL[schemeEnd+3:] - } else { - //scheme = "http" // 默认协议为 http - } - - // 找到第一个 "/" 的位置,以分隔主机名和路径 - pathStart := strings.Index(rawURL, "/") - if pathStart != -1 { - uri = rawURL[pathStart:] - rawURL = rawURL[:pathStart] - } else { - uri = "/" - } - - // 找到 ":" 的位置,以分隔主机名和端口号 - portStart := strings.LastIndex(rawURL, ":") - if portStart != -1 { - hostname = rawURL[:portStart] - port = rawURL[portStart+1:] - } else { - hostname = rawURL - port = "" // 未指定端口号 - } - - // 如果未指定端口号,根据协议设置默认端口号 - if port == "" { - //if scheme == "https" { - // port = "443" - //} else { - // port = "80" - //} - port = "80" - } - return -} diff --git a/x/httpget/_demo/get/get.go b/x/httpget/_demo/get/get.go deleted file mode 100644 index da674ba..0000000 --- a/x/httpget/_demo/get/get.go +++ /dev/null @@ -1,24 +0,0 @@ -package main - -import ( - "fmt" - "io" - - "github.com/goplus/llgo/x/httpget" -) - -func main() { - resp, err := httpget.Get("www.baidu.com") - //req, _ := httpget.NewRequest("GET", "http://www.baidu.com", nil) - //resp, err := httpget.DefaultClient.Send(req, nil) - if err != nil { - fmt.Println(err) - return - } - body, err := io.ReadAll(resp.Body) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(string(body)) -} diff --git a/x/httpget/client.go b/x/httpget/client.go deleted file mode 100644 index 8a1f610..0000000 --- a/x/httpget/client.go +++ /dev/null @@ -1,46 +0,0 @@ -package httpget - -type Client struct { - Transport RoundTripper -} - -var DefaultClient = &Client{} - -type RoundTripper interface { - RoundTrip(*Request) (*Response, error) -} - -func (c *Client) transport() RoundTripper { - if c.Transport != nil { - return c.Transport - } - return DefaultTransport -} - -func Get(url string) (*Response, error) { - return DefaultClient.Get(url) -} - -func (c *Client) Get(url string) (*Response, error) { - req, err := NewRequest("GET", url, nil) - if err != nil { - return nil, err - } - return c.Do(req) -} - -func (c *Client) Do(req *Request) (*Response, error) { - return c.do(req) -} - -func (c *Client) do(req *Request) (*Response, error) { - return c.send(req, nil) -} - -func (c *Client) send(req *Request, deadline any) (*Response, error) { - return send(req, c.transport(), deadline) -} - -func send(req *Request, rt RoundTripper, deadline any) (resp *Response, err error) { - return rt.RoundTrip(req) -} diff --git a/x/httpget/header.go b/x/httpget/header.go deleted file mode 100644 index 1768557..0000000 --- a/x/httpget/header.go +++ /dev/null @@ -1,32 +0,0 @@ -package httpget - -import ( - "fmt" - - "github.com/goplus/llgo/c" - "github.com/goplus/llgoexamples/rust/hyper" -) - -type Header map[string][]string - -// AppendToResponseHeader (HeadersForEachCallback) prints each header to the console -func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { - resp := (*Response)(userdata) - nameStr := string((*[1 << 30]byte)(c.Pointer(name))[:nameLen:nameLen]) - valueStr := string((*[1 << 30]byte)(c.Pointer(value))[:valueLen:valueLen]) - - if resp.Header == nil { - resp.Header = make(map[string][]string) - } - resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) - //c.Printf(c.Str("%.*s: %.*s\n"), int(nameLen), name, int(valueLen), value) - return hyper.IterContinue -} - -func (resp *Response) PrintHeaders() { - for key, values := range resp.Header { - for _, value := range values { - fmt.Printf("%s: %s\n", key, value) - } - } -} diff --git a/x/httpget/response.go b/x/httpget/response.go deleted file mode 100644 index a9e4468..0000000 --- a/x/httpget/response.go +++ /dev/null @@ -1,54 +0,0 @@ -package httpget - -import ( - "fmt" - "io" - "unsafe" - - "github.com/goplus/llgo/c" - "github.com/goplus/llgoexamples/rust/hyper" -) - -type Response struct { - Status string - StatusCode int - Header Header - Body io.ReadCloser - ContentLength int64 - respBodyWriter *io.PipeWriter -} - -// AppendToResponseBody (BodyForEachCallback) appends the body to the response -func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { - fmt.Println("reading1...") - resp := (*Response)(userdata) - len := chunk.Len() - buf := unsafe.Slice((*byte)(chunk.Bytes()), len) - _, err := resp.respBodyWriter.Write(buf) - resp.ContentLength += int64(len) - if err != nil { - fmt.Printf("Failed to write response body: %v\n", err) - return hyper.IterBreak - } - fmt.Println("reading2...") - return hyper.IterContinue -} - -func (resp *Response) PrintBody() { - var buffer = make([]byte, 4096) - for { - n, err := resp.Body.Read(buffer) - if err == io.EOF { - fmt.Printf("\n") - break - } - if err != nil { - fmt.Println("Error reading from pipe:", err) - break - } - fmt.Printf("%s", string(buffer[:n])) - } - buffer = nil - //body, _ := io.ReadAll(resp.Content) - //fmt.Println(string(body)) -} From 763dd7227d621dbc4dc0a9eb1de98496c31983e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Thu, 8 Aug 2024 18:04:50 +0800 Subject: [PATCH 05/21] WIP(x/http/client/get): Some code optimization and comment addition --- go.mod | 2 +- go.sum | 4 +- x/http/_demo/get/get.go | 2 +- x/http/request.go | 15 +- x/http/response.go | 40 +++--- x/http/transport.go | 305 +++++++++++++++++++++++++++------------- 6 files changed, 240 insertions(+), 128 deletions(-) diff --git a/go.mod b/go.mod index fa05f1f..978043d 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/goplus/llgoexamples go 1.20 -require github.com/goplus/llgo v0.9.5-0.20240805045323-4bff9cc3df0c +require github.com/goplus/llgo v0.9.6-0.20240808082624-c5b96f4e9cf3 diff --git a/go.sum b/go.sum index ba1d000..17cad08 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -github.com/goplus/llgo v0.9.5-0.20240805045323-4bff9cc3df0c h1:PhaSnZL8LLyRIHWc5Wim9No0Q475H8Ljikxfj1gHHjc= -github.com/goplus/llgo v0.9.5-0.20240805045323-4bff9cc3df0c/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= +github.com/goplus/llgo v0.9.6-0.20240808082624-c5b96f4e9cf3 h1:il9j5kdSnaoO57XJ8ebSHppPWIJ8iwqgcegOJNkipt4= +github.com/goplus/llgo v0.9.6-0.20240808082624-c5b96f4e9cf3/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= diff --git a/x/http/_demo/get/get.go b/x/http/_demo/get/get.go index 73e7113..c8460e4 100644 --- a/x/http/_demo/get/get.go +++ b/x/http/_demo/get/get.go @@ -13,7 +13,7 @@ func main() { fmt.Println(err) return } - println(resp.Status) + fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) resp.PrintHeaders() body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/x/http/request.go b/x/http/request.go index 98d48b4..23e98b4 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -4,15 +4,19 @@ import ( "fmt" "io" "net/url" + "time" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" ) type Request struct { - Method string - URL *url.URL - Req *hyper.Request + Method string + URL *url.URL + Req *hyper.Request + Host string + Header Header + Timeout time.Duration } func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { @@ -20,7 +24,7 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { if err != nil { return nil, err } - req, err := NewHyperRequest(method, parseURL) + req, err := newHyperRequest(method, parseURL) if err != nil { return nil, err } @@ -28,10 +32,11 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { Method: method, URL: parseURL, Req: req, + Host: parseURL.Hostname(), }, nil } -func NewHyperRequest(method string, URL *url.URL) (*hyper.Request, error) { +func newHyperRequest(method string, URL *url.URL) (*hyper.Request, error) { host := URL.Hostname() uri := URL.RequestURI() // Prepare the request diff --git a/x/http/response.go b/x/http/response.go index 8d01b80..2f3a641 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -1,33 +1,27 @@ package http import ( - "fmt" "io" - "unsafe" - - "github.com/goplus/llgo/c" - "github.com/goplus/llgoexamples/rust/hyper" ) type Response struct { - Status string - StatusCode int - Header Header - Body io.ReadCloser - ContentLength int64 - respBodyWriter *io.PipeWriter + Status string + StatusCode int + Header Header + Body io.ReadCloser + ContentLength int64 } // AppendToResponseBody (BodyForEachCallback) appends the body to the response -func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { - resp := (*Response)(userdata) - len := chunk.Len() - buf := unsafe.Slice((*byte)(chunk.Bytes()), len) - _, err := resp.respBodyWriter.Write(buf) - resp.ContentLength += int64(len) - if err != nil { - fmt.Printf("Failed to write response body: %v\n", err) - return hyper.IterBreak - } - return hyper.IterContinue -} +//func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { +// resp := (*Response)(userdata) +// len := chunk.Len() +// buf := unsafe.Slice((*byte)(chunk.Bytes()), len) +// _, err := resp.respBodyWriter.Write(buf) +// resp.ContentLength += int64(len) +// if err != nil { +// fmt.Printf("Failed to write response body: %v\n", err) +// return hyper.IterBreak +// } +// return hyper.IterContinue +//} diff --git a/x/http/transport.go b/x/http/transport.go index c2559a8..ed08e45 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -2,7 +2,7 @@ package http import ( "fmt" - io2 "io" + "io" "strconv" "unsafe" @@ -35,6 +35,10 @@ const ( ReceiveRespBody ) +const ( + DefaultHTTPPort = "80" +) + var DefaultTransport RoundTripper = &Transport{} // persistConn wraps a connection, usually a persistent one @@ -87,14 +91,15 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { - host := req.URL.Hostname() + host := req.Host port := req.URL.Port() if port == "" { // Hyper only supports http - port = "80" + port = DefaultHTTPPort } loop := libuv.DefaultLoop() - conn := (*ConnData)(c.Malloc(unsafe.Sizeof(ConnData{}))) + //conn := (*ConnData)(c.Calloc(1, unsafe.Sizeof(ConnData{}))) + conn := new(ConnData) if conn == nil { return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") } @@ -144,7 +149,7 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err error) { select { case re := <-resc: if (re.res == nil) == (re.err == nil) { - panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil)) + return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) } if re.err != nil { return nil, err @@ -153,18 +158,19 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err error) { } } +// readWriteLoop handles the main I/O loop for a persistent connection. +// It processes incoming requests, sends them to the server, and handles responses. func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Hookup the IO - io := NewIoWithConnReadWrite(pc.conn) + hyperIo := NewIoWithConnReadWrite(pc.conn) // We need an executor generally to poll futures exec := hyper.NewExecutor() - // Prepare client options opts := hyper.NewClientConnOptions() opts.Exec(exec) - handshakeTask := hyper.Handshake(io, opts) + handshakeTask := hyper.Handshake(hyperIo, opts) SetTaskId(handshakeTask, Send) // Let's wait for the handshake to finish... @@ -175,27 +181,25 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Poll all ready tasks and act on them... rc := <-pc.reqch // blocking alive := true - var hyperErr *hyper.Error var response Response + var bodyWriter *io.PipeWriter var respBody *hyper.Body = nil for alive { task := exec.Poll() if task == nil { //break - libuv.Run(loop, libuv.RUN_ONCE) + loop.Run(libuv.RUN_ONCE) continue } switch (TaskId)(uintptr(task.Userdata())) { case Send: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("handshake error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - Fail(hyperErr) - } - if task.Type() != hyper.TaskClientConn { - c.Printf(c.Str("unexpected task type\n")) - Fail(hyperErr) + err := CheckTaskType(task, Send) + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return } client := (*hyper.ClientConn)(task.Value()) @@ -206,21 +210,21 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { SetTaskId(sendTask, ReceiveResp) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { - panic("error send\n") + rc.ch <- responseAndError{err: fmt.Errorf("failed to send request")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return } // For this example, no longer need the client client.Free() - case ReceiveResp: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("send error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - Fail(hyperErr) - } - if task.Type() != hyper.TaskResponse { - c.Printf(c.Str("unexpected task type\n")) - Fail(hyperErr) + err := CheckTaskType(task, ReceiveResp) + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return } // Take the results @@ -235,14 +239,25 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { headers := resp.Headers() headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) - //respBody := resp.Body() respBody = resp.Body() - response.Body, response.respBodyWriter = io2.Pipe() + response.Body, bodyWriter = io.Pipe() - //foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) - //SetTaskId(foreachTask, ReceiveRespBody) - //exec.Push(foreachTask) + // TODO(spongehah) Replace header operations with using the textproto package + lengthSlice := response.Header["content-length"] + if lengthSlice == nil { + response.ContentLength = 0 + } else { + contentLength := response.Header["content-length"][0] + length, err := strconv.Atoi(contentLength) + if err != nil { + rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + response.ContentLength = int64(length) + } rc.ch <- responseAndError{res: &response} @@ -252,20 +267,31 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // No longer need the response resp.Free() - case ReceiveRespBody: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("body error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - Fail(hyperErr) + err := CheckTaskType(task, ReceiveRespBody) + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return } + if task.Type() == hyper.TaskBuf { buf := (*hyper.Buf)(task.Value()) bufLen := buf.Len() bytes := unsafe.Slice((*byte)(buf.Bytes()), bufLen) - _, err := response.respBodyWriter.Write(bytes) // blocking + if bodyWriter == nil { + rc.ch <- responseAndError{err: fmt.Errorf("ResponseBodyWriter is nil")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + _, err := bodyWriter.Write(bytes) // blocking if err != nil { - panic("[readWriteLoop(): case ReceiveRespBody] error write\n") + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return } buf.Free() task.Free() @@ -276,21 +302,18 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { break } - // task.Type() == hyper.TaskEmpty + + // We are done with the response body if task.Type() != hyper.TaskEmpty { c.Printf(c.Str("unexpected task type\n")) - Fail(hyperErr) + rc.ch <- responseAndError{err: fmt.Errorf("unexpected task type\n")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return } - // Cleaning up before exiting - task.Free() - respBody.Free() - response.respBodyWriter.Close() - exec.Free() - (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) - FreeConnData(pc.conn) - close(rc.ch) - close(pc.reqch) + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) alive = false case NotSet: @@ -301,6 +324,19 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { //} } +// OnConnect is the libuv callback for a successful connection +func OnConnect(req *libuv.Connect, status c.Int) { + //conn := (*ConnData)(req.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(req)).data + conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + + if status < 0 { + c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) + return + } + (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(AllocBuffer, OnRead) +} + // AllocBuffer allocates a buffer for reading from a socket func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { //conn := (*ConnData)(handle.Data) @@ -314,108 +350,160 @@ func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { } // OnRead is the libuv callback for reading from a socket +// This callback function is called when data is available to be read func OnRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { + // Get the connection data associated with the stream + conn := (*ConnData)((*libuv.Handle)(c.Pointer(stream)).GetData()) //conn := (*ConnData)(stream.Data) //conn := (*struct{ data *ConnData })(c.Pointer(stream)).data - conn := (*ConnData)((*libuv.Handle)(c.Pointer(stream)).GetData()) + + // If data was read (nread > 0) if nread > 0 { + // Update the amount of filled buffer conn.ReadBufFilled += uintptr(nread) } + // If there's a pending read waker if conn.ReadWaker != nil { + // Wake up the pending read operation of Hyper conn.ReadWaker.Wake() + // Clear the waker reference conn.ReadWaker = nil } } -// ReadCallBack is the hyper callback for reading from a socket +// ReadCallBack read callback function for Hyper library func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + // Get the user data (connection data) conn := (*ConnData)(userdata) + // If there's data in the buffer if conn.ReadBufFilled > 0 { + // Calculate how much data to copy (minimum of filled amount and requested amount) var toCopy uintptr if bufLen < conn.ReadBufFilled { toCopy = bufLen } else { toCopy = conn.ReadBufFilled } + // Copy data from read buffer to Hyper's buffer c.Memcpy(c.Pointer(buf), c.Pointer(conn.ReadBuf.Base), toCopy) + // Move remaining data to the beginning of the buffer c.Memmove(c.Pointer(conn.ReadBuf.Base), c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+toCopy), conn.ReadBufFilled-toCopy) + // Update the amount of filled buffer conn.ReadBufFilled -= toCopy + // Return the number of bytes copied return toCopy } + // If no data in buffer, set up a waker to wait for more data + // Free the old waker if it exists if conn.ReadWaker != nil { conn.ReadWaker.Free() } + // Create a new waker conn.ReadWaker = ctx.Waker() + // Return HYPER_IO_PENDING to indicate operation is pending, waiting for more data return hyper.IoPending } // OnWrite is the libuv callback for writing to a socket +// Callback function called after a write operation completes func OnWrite(req *libuv.Write, status c.Int) { + // Get the connection data associated with the write request + conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) //conn := (*ConnData)(req.Data) //conn := (*struct{ data *ConnData })(c.Pointer(req)).data - conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + // If there's a pending write waker if conn.WriteWaker != nil { + // Wake up the pending write operation conn.WriteWaker.Wake() + // Clear the waker reference conn.WriteWaker = nil } - c.Free(c.Pointer(req)) } -// WriteCallBack is the hyper callback for writing to a socket +// WriteCallBack write callback function for Hyper library func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + // Get the user data (connection data) conn := (*ConnData)(userdata) + // Create a libuv buffer initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) - req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) - //req.Data = c.Pointer(conn) + //req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) + req := &libuv.Write{} + // Associate the connection data with the write request (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) - ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, OnWrite) + //req.Data = c.Pointer(conn) + // Perform the asynchronous write operation + ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, OnWrite) + // If the write operation was successfully initiated if ret >= 0 { + // Return the number of bytes to be written return bufLen } + // If the write operation can't complete immediately, set up a waker to wait for completion if conn.WriteWaker != nil { + // Free the old waker if it exists conn.WriteWaker.Free() } + // Create a new waker conn.WriteWaker = ctx.Waker() + // Return HYPER_IO_PENDING to indicate operation is pending, waiting for write to complete return hyper.IoPending } -// OnConnect is the libuv callback for a successful connection -func OnConnect(req *libuv.Connect, status c.Int) { - //conn := (*ConnData)(req.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(req)).data - conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) +// NewIoWithConnReadWrite creates a new IO with read and write callbacks +func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { + hyperIo := hyper.NewIo() + hyperIo.SetUserdata(c.Pointer(connData)) + hyperIo.SetRead(ReadCallBack) + hyperIo.SetWrite(WriteCallBack) + return hyperIo +} - if status < 0 { - c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) - return - } - (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(AllocBuffer, OnRead) +// SetTaskId Set TaskId to the task's userdata as a unique identifier +func SetTaskId(task *hyper.Task, userData TaskId) { + var data = userData + task.SetUserdata(unsafe.Pointer(uintptr(data))) } -// FreeConnData frees the connection data -func FreeConnData(conn *ConnData) { - if conn.ReadWaker != nil { - conn.ReadWaker.Free() - conn.ReadWaker = nil - } - if conn.WriteWaker != nil { - conn.WriteWaker.Free() - conn.WriteWaker = nil - } - if conn.ReadBuf.Base != nil { - c.Free(c.Pointer(conn.ReadBuf.Base)) - conn.ReadBuf.Base = nil +// CheckTaskType checks the task type +func CheckTaskType(task *hyper.Task, curTaskId TaskId) error { + switch curTaskId { + case Send: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("handshake task error!\n")) + return Fail((*hyper.Error)(task.Value())) + } + if task.Type() != hyper.TaskClientConn { + return fmt.Errorf("unexpected task type\n") + } + return nil + case ReceiveResp: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("send task error!\n")) + return Fail((*hyper.Error)(task.Value())) + } + if task.Type() != hyper.TaskResponse { + c.Printf(c.Str("unexpected task type\n")) + return fmt.Errorf("unexpected task type\n") + } + return nil + case ReceiveRespBody: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("body error!\n")) + return Fail((*hyper.Error)(task.Value())) + } + return nil + case NotSet: } - c.Free(c.Pointer(conn)) + return fmt.Errorf("unexpected TaskId\n") } // Fail prints the error details and panics -func Fail(err *hyper.Error) { +func Fail(err *hyper.Error) error { if err != nil { c.Printf(c.Str("error code: %d\n"), err.Code()) // grab the error details @@ -426,21 +514,46 @@ func Fail(err *hyper.Error) { // clean up the error err.Free() - panic("hyper error \n") + return fmt.Errorf("hyper request error, error code: %d\n", int(err.Code())) } + return nil } -// NewIoWithConnReadWrite creates a new IO with read and write callbacks -func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { - io := hyper.NewIo() - io.SetUserdata(c.Pointer(connData)) - io.SetRead(ReadCallBack) - io.SetWrite(WriteCallBack) - return io +// FreeResources frees the resources +func FreeResources(task *hyper.Task, respBody *hyper.Body, bodyWriter *io.PipeWriter, exec *hyper.Executor, pc *persistConn, rc requestAndChan) { + // Cleaning up before exiting + if task != nil { + task.Free() + } + if respBody != nil { + respBody.Free() + } + if bodyWriter != nil { + bodyWriter.Close() + } + if exec != nil { + exec.Free() + } + (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) + FreeConnData(pc.conn) + + // Closing the channel + close(rc.ch) + close(pc.reqch) } -// SetTaskId Set TaskId to the task's userdata as a unique identifier -func SetTaskId(task *hyper.Task, userData TaskId) { - var data = userData - task.SetUserdata(c.Pointer(uintptr(data))) +// FreeConnData frees the connection data +func FreeConnData(conn *ConnData) { + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + conn.ReadWaker = nil + } + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + conn.WriteWaker = nil + } + if conn.ReadBuf.Base != nil { + c.Free(c.Pointer(conn.ReadBuf.Base)) + conn.ReadBuf.Base = nil + } } From b9f4944b87bbe886d6005a38e2404860e354adeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Fri, 9 Aug 2024 18:06:57 +0800 Subject: [PATCH 06/21] WIP(c/http/client): Add request timeout logic --- go.mod | 2 +- go.sum | 4 +- x/http/_demo/timeout/timeout.go | 33 ++++ x/http/client.go | 12 +- x/http/request.go | 11 +- x/http/transport.go | 319 ++++++++++++++++++++------------ 6 files changed, 250 insertions(+), 131 deletions(-) create mode 100644 x/http/_demo/timeout/timeout.go diff --git a/go.mod b/go.mod index 978043d..4082df2 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/goplus/llgoexamples go 1.20 -require github.com/goplus/llgo v0.9.6-0.20240808082624-c5b96f4e9cf3 +require github.com/goplus/llgo v0.9.7-0.20240812013847-321766fd4641 diff --git a/go.sum b/go.sum index 17cad08..e3abd53 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -github.com/goplus/llgo v0.9.6-0.20240808082624-c5b96f4e9cf3 h1:il9j5kdSnaoO57XJ8ebSHppPWIJ8iwqgcegOJNkipt4= -github.com/goplus/llgo v0.9.6-0.20240808082624-c5b96f4e9cf3/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= +github.com/goplus/llgo v0.9.7-0.20240812013847-321766fd4641 h1:VIJ38bCFRIIr62YXyRKkxy6GXYVA6R3xqAb0HkcoUgw= +github.com/goplus/llgo v0.9.7-0.20240812013847-321766fd4641/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= diff --git a/x/http/_demo/timeout/timeout.go b/x/http/_demo/timeout/timeout.go new file mode 100644 index 0000000..42f8bf8 --- /dev/null +++ b/x/http/_demo/timeout/timeout.go @@ -0,0 +1,33 @@ +package main + +import ( + "fmt" + "io" + "time" + + "github.com/goplus/llgoexamples/x/http" +) + +func main() { + client := &http.Client{ + Timeout: time.Millisecond, // Set a small timeout to ensure it will time out + //Timeout: time.Second * 5, + } + req, err := http.NewRequest("GET", "https://www.baidu.com", nil) + if err != nil { + fmt.Println(err.Error()) + return + } + resp, err := client.Do(req) + if err != nil { + fmt.Println(err.Error()) + return + } + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err.Error()) + return + } + println(string(body)) + defer resp.Body.Close() +} diff --git a/x/http/client.go b/x/http/client.go index ac0bc6e..9ce8506 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -1,7 +1,10 @@ package http +import "time" + type Client struct { Transport RoundTripper + Timeout time.Duration } var DefaultClient = &Client{} @@ -34,13 +37,14 @@ func (c *Client) Do(req *Request) (*Response, error) { } func (c *Client) do(req *Request) (*Response, error) { - return c.send(req, nil) + return c.send(req, c.Timeout) } -func (c *Client) send(req *Request, deadline any) (*Response, error) { - return send(req, c.transport(), deadline) +func (c *Client) send(req *Request, timeout time.Duration) (*Response, error) { + return send(req, c.transport(), timeout) } -func send(req *Request, rt RoundTripper, deadline any) (resp *Response, err error) { +func send(req *Request, rt RoundTripper, timeout time.Duration) (resp *Response, err error) { + req.timeout = timeout return rt.RoundTrip(req) } diff --git a/x/http/request.go b/x/http/request.go index 23e98b4..2e04939 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -16,7 +16,7 @@ type Request struct { Req *hyper.Request Host string Header Header - Timeout time.Duration + timeout time.Duration } func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { @@ -29,10 +29,11 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { return nil, err } return &Request{ - Method: method, - URL: parseURL, - Req: req, - Host: parseURL.Hostname(), + Method: method, + URL: parseURL, + Req: req, + Host: parseURL.Hostname(), + timeout: 0, }, nil } diff --git a/x/http/transport.go b/x/http/transport.go index ed08e45..1eed648 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -17,6 +17,8 @@ type ConnData struct { TcpHandle libuv.Tcp ConnectReq libuv.Connect ReadBuf libuv.Buf + TimeoutTimer libuv.Timer + IsCompleted int ReadBufFilled uintptr ReadWaker *hyper.Waker WriteWaker *hyper.Waker @@ -53,9 +55,11 @@ type persistConn struct { //nwrite int64 // bytes written //writech chan writeRequest // written by roundTrip; read by writeLoop //closech chan struct{} // closed when conn closed - conn *ConnData - t *Transport - reqch chan requestAndChan // written by roundTrip; read by readLoop + conn *ConnData + t *Transport + reqch chan requestAndChan // written by roundTrip; read by readLoop + cancelch chan freeChan + timeoutch chan struct{} } // incomparable is a zero-width, non-comparable type. Adding it to a struct @@ -65,7 +69,7 @@ type incomparable [0]func() type requestAndChan struct { _ incomparable - req *hyper.Request + req *Request ch chan responseAndError // unbuffered; always send in select on callerGone } @@ -77,6 +81,17 @@ type responseAndError struct { err error } +type connAndTimeoutChan struct { + _ incomparable + conn *ConnData + timeoutch chan struct{} +} + +type freeChan struct { + _ incomparable + freech chan struct{} +} + func (t *Transport) RoundTrip(req *Request) (*Response, error) { pconn, err := t.getConn(req) if err != nil { @@ -104,6 +119,18 @@ func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") } + // If timeout is set, start the timer + timeoutch := make(chan struct{}, 1) + if req.timeout != 0 { + libuv.InitTimer(loop, &conn.TimeoutTimer) + ct := &connAndTimeoutChan{ + conn: conn, + timeoutch: timeoutch, + } + (*libuv.Handle)(c.Pointer(&conn.TimeoutTimer)).SetData(c.Pointer(ct)) + conn.TimeoutTimer.Start(OnTimeout, uint64(req.timeout.Milliseconds()), 0) + } + libuv.InitTcp(loop, &conn.TcpHandle) //conn.TcpHandle.Data = c.Pointer(conn) (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) @@ -116,6 +143,7 @@ func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { var res *net.AddrInfo status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) if status != 0 { + close(timeoutch) return nil, fmt.Errorf("getaddrinfo error\n") } @@ -123,38 +151,57 @@ func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, OnConnect) if status != 0 { + close(timeoutch) return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } pconn = &persistConn{ - conn: conn, - t: t, - reqch: make(chan requestAndChan, 1), + conn: conn, + t: t, + reqch: make(chan requestAndChan, 1), + cancelch: make(chan freeChan, 1), + timeoutch: timeoutch, //writech: make(chan writeRequest, 1), //closech: make(chan struct{}), } net.Freeaddrinfo(res) - go pconn.readWriteLoop(loop) + if pconn.conn.IsCompleted != 1 { + go pconn.readWriteLoop(loop) + } return pconn, nil } -func (pc *persistConn) roundTrip(req *Request) (resp *Response, err error) { - resc := make(chan responseAndError) +func (pc *persistConn) roundTrip(req *Request) (*Response, error) { + resc := make(chan responseAndError, 1) + pc.reqch <- requestAndChan{ - req: req.Req, + req: req, ch: resc, } - + // Determine whether timeout has occurred + if pc.conn.IsCompleted == 1 { + rc := <-pc.reqch // blocking + // Free the resources + FreeResources(nil, nil, nil, nil, pc, rc) + } select { case re := <-resc: if (re.res == nil) == (re.err == nil) { return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) } if re.err != nil { - return nil, err + return nil, re.err } return re.res, nil + case <-pc.timeoutch: + freech := make(chan struct{}, 1) + pc.cancelch <- freeChan{ + freech: freech, + } + <-freech + close(freech) + return nil, fmt.Errorf("request timeout\n") } } @@ -185,140 +232,156 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { var bodyWriter *io.PipeWriter var respBody *hyper.Body = nil for alive { - task := exec.Poll() - if task == nil { - //break - loop.Run(libuv.RUN_ONCE) - continue - } - - switch (TaskId)(uintptr(task.Userdata())) { - case Send: - err := CheckTaskType(task, Send) - if err != nil { - rc.ch <- responseAndError{err: err} - // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) - return - } - - client := (*hyper.ClientConn)(task.Value()) - task.Free() - - // Send it! - sendTask := client.Send(rc.req) - SetTaskId(sendTask, ReceiveResp) - sendRes := exec.Push(sendTask) - if sendRes != hyper.OK { - rc.ch <- responseAndError{err: fmt.Errorf("failed to send request")} - // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) - return - } - - // For this example, no longer need the client - client.Free() - case ReceiveResp: - err := CheckTaskType(task, ReceiveResp) - if err != nil { - rc.ch <- responseAndError{err: err} - // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) - return + select { + case fc := <-pc.cancelch: + // Free the resources + FreeResources(nil, respBody, bodyWriter, exec, pc, rc) + alive = false + fc.freech <- struct{}{} + return + default: + task := exec.Poll() + if task == nil { + //break + loop.Run(libuv.RUN_ONCE) + continue } - - // Take the results - resp := (*hyper.Response)(task.Value()) - task.Free() - - rp := resp.ReasonPhrase() - rpLen := resp.ReasonPhraseLen() - - response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) - response.StatusCode = int(resp.Status()) - - headers := resp.Headers() - headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) - respBody = resp.Body() - - response.Body, bodyWriter = io.Pipe() - - // TODO(spongehah) Replace header operations with using the textproto package - lengthSlice := response.Header["content-length"] - if lengthSlice == nil { - response.ContentLength = 0 - } else { - contentLength := response.Header["content-length"][0] - length, err := strconv.Atoi(contentLength) + switch (TaskId)(uintptr(task.Userdata())) { + case Send: + err := CheckTaskType(task, Send) if err != nil { - rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} + rc.ch <- responseAndError{err: err} // Free the resources FreeResources(task, respBody, bodyWriter, exec, pc, rc) return } - response.ContentLength = int64(length) - } - - rc.ch <- responseAndError{res: &response} - - dataTask := respBody.Data() - SetTaskId(dataTask, ReceiveRespBody) - exec.Push(dataTask) - // No longer need the response - resp.Free() - case ReceiveRespBody: - err := CheckTaskType(task, ReceiveRespBody) - if err != nil { - rc.ch <- responseAndError{err: err} - // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) - return - } + client := (*hyper.ClientConn)(task.Value()) + task.Free() - if task.Type() == hyper.TaskBuf { - buf := (*hyper.Buf)(task.Value()) - bufLen := buf.Len() - bytes := unsafe.Slice((*byte)(buf.Bytes()), bufLen) - if bodyWriter == nil { - rc.ch <- responseAndError{err: fmt.Errorf("ResponseBodyWriter is nil")} + // Send it! + sendTask := client.Send(rc.req.Req) + SetTaskId(sendTask, ReceiveResp) + sendRes := exec.Push(sendTask) + if sendRes != hyper.OK { + rc.ch <- responseAndError{err: fmt.Errorf("failed to send request")} // Free the resources FreeResources(task, respBody, bodyWriter, exec, pc, rc) return } - _, err := bodyWriter.Write(bytes) // blocking + + // For this example, no longer need the client + client.Free() + case ReceiveResp: + err := CheckTaskType(task, ReceiveResp) if err != nil { rc.ch <- responseAndError{err: err} // Free the resources FreeResources(task, respBody, bodyWriter, exec, pc, rc) return } - buf.Free() + + // Take the results + resp := (*hyper.Response)(task.Value()) task.Free() + rp := resp.ReasonPhrase() + rpLen := resp.ReasonPhraseLen() + + response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) + response.StatusCode = int(resp.Status()) + + headers := resp.Headers() + headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) + respBody = resp.Body() + + response.Body, bodyWriter = io.Pipe() + + // TODO(spongehah) Replace header operations with using the textproto package + lengthSlice := response.Header["content-length"] + if lengthSlice == nil { + response.ContentLength = 0 + } else { + contentLength := response.Header["content-length"][0] + length, err := strconv.Atoi(contentLength) + if err != nil { + rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + response.ContentLength = int64(length) + } + + rc.ch <- responseAndError{res: &response} + + // Response has been returned, stop the timer + pc.conn.IsCompleted = 1 + // Stop the timer + if rc.req.timeout != 0 { + pc.conn.TimeoutTimer.Stop() + (*libuv.Handle)(c.Pointer(&pc.conn.TimeoutTimer)).Close(nil) + } + dataTask := respBody.Data() SetTaskId(dataTask, ReceiveRespBody) exec.Push(dataTask) - break - } + // No longer need the response + resp.Free() + case ReceiveRespBody: + err := CheckTaskType(task, ReceiveRespBody) + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + + if task.Type() == hyper.TaskBuf { + buf := (*hyper.Buf)(task.Value()) + bufLen := buf.Len() + bytes := unsafe.Slice((*byte)(buf.Bytes()), bufLen) + if bodyWriter == nil { + rc.ch <- responseAndError{err: fmt.Errorf("ResponseBodyWriter is nil")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + _, err := bodyWriter.Write(bytes) // blocking + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + buf.Free() + task.Free() + + dataTask := respBody.Data() + SetTaskId(dataTask, ReceiveRespBody) + exec.Push(dataTask) + + break + } + + // We are done with the response body + if task.Type() != hyper.TaskEmpty { + c.Printf(c.Str("unexpected task type\n")) + rc.ch <- responseAndError{err: fmt.Errorf("unexpected task type\n")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } - // We are done with the response body - if task.Type() != hyper.TaskEmpty { - c.Printf(c.Str("unexpected task type\n")) - rc.ch <- responseAndError{err: fmt.Errorf("unexpected task type\n")} // Free the resources FreeResources(task, respBody, bodyWriter, exec, pc, rc) - return - } - - // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) - alive = false - case NotSet: - // A background task for hyper_client completed... - task.Free() + alive = false + case NotSet: + // A background task for hyper_client completed... + task.Free() + } } } //} @@ -454,6 +517,17 @@ func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui return hyper.IoPending } +// OnTimeout is the libuv callback for a timeout +func OnTimeout(handle *libuv.Timer) { + ct := (*connAndTimeoutChan)((*libuv.Handle)(c.Pointer(handle)).GetData()) + if ct.conn.IsCompleted != 1 { + ct.conn.IsCompleted = 1 + ct.timeoutch <- struct{}{} + } + // Close the timer + (*libuv.Handle)(c.Pointer(&ct.conn.TimeoutTimer)).Close(nil) +} + // NewIoWithConnReadWrite creates a new IO with read and write callbacks func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { hyperIo := hyper.NewIo() @@ -537,9 +611,16 @@ func FreeResources(task *hyper.Task, respBody *hyper.Body, bodyWriter *io.PipeWr (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) FreeConnData(pc.conn) + CloseChannels(rc, pc) +} + +// CloseChannels closes the channels +func CloseChannels(rc requestAndChan, pc *persistConn) { // Closing the channel close(rc.ch) close(pc.reqch) + close(pc.timeoutch) + close(pc.cancelch) } // FreeConnData frees the connection data From 2c9394c1f6d590b686cbb450725573801d0d4aaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Mon, 12 Aug 2024 16:47:07 +0800 Subject: [PATCH 07/21] WIP(x/http/client/get): Introducing textproto for header & implementing custom header --- x/http/_demo/headers/headers.go | 45 ++++++++++++++++++++++ x/http/client.go | 5 +++ x/http/header.go | 67 ++++++++++++++++++++++++++++++++- x/http/request.go | 33 ++++++++++++++-- x/http/response.go | 14 ------- x/http/transport.go | 25 +++++++++--- 6 files changed, 164 insertions(+), 25 deletions(-) create mode 100644 x/http/_demo/headers/headers.go diff --git a/x/http/_demo/headers/headers.go b/x/http/_demo/headers/headers.go new file mode 100644 index 0000000..98cda79 --- /dev/null +++ b/x/http/_demo/headers/headers.go @@ -0,0 +1,45 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgo/x/http" +) + +func main() { + client := &http.Client{} + req, err := http.NewRequest("GET", "https://jsonplaceholder.typicode.com/comments?postId=1", nil) + if err != nil { + println(err.Error()) + return + } + + //req.Header.Set("accept", "*/*") + //req.Header.Set("accept-encoding", "identity") + //req.Header.Set("cache-control", "no-cache") + //req.Header.Set("pragma", "no-cache") + //req.Header.Set("priority", "u=0, i") + //req.Header.Set("referer", "https://jsonplaceholder.typicode.com/") + //req.Header.Set("sec-ch-ua", "\"Not)A;Brand\";v=\"99\", \"Google Chrome\";v=\"127\", \"Chromium\";v=\"127\"") + //req.Header.Set("sec-ch-ua-mobile", "?0") + //req.Header.Set("sec-ch-ua-platform", "\"macOS\"") + //req.Header.Set("sec-fetch-dest", "document") + //req.Header.Set("sec-fetch-mode", "navigate") + //req.Header.Set("sec-fetch-site", "same-origin") + //req.Header.Set("sec-fetch-user", "?1") + //req.Header.Set("upgrade-insecure-requests", "1") + //req.Header.Set("user-agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36") + + resp, err := client.Do(req) + if err != nil { + println(err.Error()) + return + } + body, err := io.ReadAll(resp.Body) + if err != nil { + println(err.Error()) + return + } + fmt.Println(string(body)) +} diff --git a/x/http/client.go b/x/http/client.go index 9ce8506..177b089 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -37,6 +37,11 @@ func (c *Client) Do(req *Request) (*Response, error) { } func (c *Client) do(req *Request) (*Response, error) { + // Add user-defined request headers to hyper.Request + err := req.setHeaders() + if err != nil { + return nil, err + } return c.send(req, c.Timeout) } diff --git a/x/http/header.go b/x/http/header.go index 4710854..aac05ce 100644 --- a/x/http/header.go +++ b/x/http/header.go @@ -4,11 +4,74 @@ import ( "fmt" "github.com/goplus/llgo/c" + "github.com/goplus/llgo/x/textproto" "github.com/goplus/llgoexamples/rust/hyper" ) +// A Header represents the key-value pairs in an HTTP header. +// +// The keys should be in canonical form, as returned by +// CanonicalHeaderKey. type Header map[string][]string +// Add adds the key, value pair to the header. +// It appends to any existing values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (h Header) Add(key, value string) { + textproto.MIMEHeader(h).Add(key, value) +} + +// Set sets the header entries associated with key to the +// single element value. It replaces any existing values +// associated with key. The key is case insensitive; it is +// canonicalized by textproto.CanonicalMIMEHeaderKey. +// To use non-canonical keys, assign to the map directly. +func (h Header) Set(key, value string) { + textproto.MIMEHeader(h).Set(key, value) +} + +// Get gets the first value associated with the given key. If +// there are no values associated with the key, Get returns "". +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. Get assumes that all +// keys are stored in canonical form. To use non-canonical keys, +// access the map directly. +func (h Header) Get(key string) string { + return textproto.MIMEHeader(h).Get(key) +} + +// Values returns all values associated with the given key. +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. To use non-canonical +// keys, access the map directly. +// The returned slice is not a copy. +func (h Header) Values(key string) []string { + return textproto.MIMEHeader(h).Values(key) +} + +// get is like Get, but key must already be in CanonicalHeaderKey form. +func (h Header) get(key string) string { + if v := h[key]; len(v) > 0 { + return v[0] + } + return "" +} + +// has reports whether h has the provided key defined, even if it's +// set to 0-length slice. +func (h Header) has(key string) bool { + _, ok := h[key] + return ok +} + +// Del deletes the values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (h Header) Del(key string) { + textproto.MIMEHeader(h).Del(key) +} + // AppendToResponseHeader (HeadersForEachCallback) prints each header to the console func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { resp := (*Response)(userdata) @@ -18,8 +81,8 @@ func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, va if resp.Header == nil { resp.Header = make(map[string][]string) } - resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) - //c.Printf(c.Str("%.*s: %.*s\n"), int(nameLen), name, int(valueLen), value) + resp.Header.Add(nameStr, valueStr) + //resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) return hyper.IterContinue } diff --git a/x/http/request.go b/x/http/request.go index 2e04939..933ae3b 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -28,13 +28,16 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { if err != nil { return nil, err } - return &Request{ + request := &Request{ Method: method, URL: parseURL, Req: req, Host: parseURL.Hostname(), + Header: make(Header), timeout: 0, - }, nil + } + request.Header.Set("Host", request.Host) + return request, nil } func newHyperRequest(method string, URL *url.URL) (*hyper.Request, error) { @@ -49,11 +52,33 @@ func newHyperRequest(method string, URL *url.URL) (*hyper.Request, error) { if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { return nil, fmt.Errorf("error setting uri %s\n", uri) } - // Set the request headers reqHeaders := req.Headers() if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { - return nil, fmt.Errorf("error setting headers\n") + return nil, fmt.Errorf("error setting header: Host: %s\n", host) } + return req, nil } + +// setHeaders sets the headers of the request +func (req *Request) setHeaders() error { + headers := req.Req.Headers() + for key, values := range req.Header { + valueLen := len(values) + if valueLen > 1 { + for _, value := range values { + if headers.Add((*uint8)(&[]byte(key)[0]), c.Strlen(c.AllocaCStr(key)), (*uint8)(&[]byte(value)[0]), c.Strlen(c.AllocaCStr(value))) != hyper.OK { + return fmt.Errorf("error adding header %s: %s\n", key, value) + } + } + } else if valueLen == 1 { + if headers.Set((*uint8)(&[]byte(key)[0]), c.Strlen(c.AllocaCStr(key)), (*uint8)(&[]byte(values[0])[0]), c.Strlen(c.AllocaCStr(values[0]))) != hyper.OK { + return fmt.Errorf("error setting header %s: %s\n", key, values[0]) + } + } else { + return fmt.Errorf("error setting header %s: empty value\n", key) + } + } + return nil +} diff --git a/x/http/response.go b/x/http/response.go index 2f3a641..e69c273 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -11,17 +11,3 @@ type Response struct { Body io.ReadCloser ContentLength int64 } - -// AppendToResponseBody (BodyForEachCallback) appends the body to the response -//func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { -// resp := (*Response)(userdata) -// len := chunk.Len() -// buf := unsafe.Slice((*byte)(chunk.Bytes()), len) -// _, err := resp.respBodyWriter.Write(buf) -// resp.ContentLength += int64(len) -// if err != nil { -// fmt.Printf("Failed to write response body: %v\n", err) -// return hyper.IterBreak -// } -// return hyper.IterContinue -//} diff --git a/x/http/transport.go b/x/http/transport.go index 1eed648..ec5ae02 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -184,6 +184,7 @@ func (pc *persistConn) roundTrip(req *Request) (*Response, error) { rc := <-pc.reqch // blocking // Free the resources FreeResources(nil, nil, nil, nil, pc, rc) + return nil, fmt.Errorf("request timeout\n") } select { case re := <-resc: @@ -297,12 +298,26 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { response.Body, bodyWriter = io.Pipe() - // TODO(spongehah) Replace header operations with using the textproto package - lengthSlice := response.Header["content-length"] - if lengthSlice == nil { - response.ContentLength = 0 + //// TODO(spongehah) Replace header operations with using the textproto package + //lengthSlice := response.Header["content-length"] + //if lengthSlice == nil { + // response.ContentLength = -1 + //} else { + // contentLength := response.Header["content-length"][0] + // length, err := strconv.Atoi(contentLength) + // if err != nil { + // rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} + // // Free the resources + // FreeResources(task, respBody, bodyWriter, exec, pc, rc) + // return + // } + // response.ContentLength = int64(length) + //} + + contentLength := response.Header.Get("content-length") + if contentLength == "" { + response.ContentLength = -1 } else { - contentLength := response.Header["content-length"][0] length, err := strconv.Atoi(contentLength) if err != nil { rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} From 5744fd69371df5ebc36d29579e63fd84e346f007 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Tue, 13 Aug 2024 18:16:44 +0800 Subject: [PATCH 08/21] WIP(x/http/client/get): Extract the readTransfer function and complete its content --- x/http/header.go | 122 ++++++++-------- x/http/request.go | 3 +- x/http/response.go | 30 +++- x/http/transfer.go | 332 ++++++++++++++++++++++++++++++++++++++++++++ x/http/transport.go | 61 +++----- 5 files changed, 444 insertions(+), 104 deletions(-) create mode 100644 x/http/transfer.go diff --git a/x/http/header.go b/x/http/header.go index aac05ce..0533ed7 100644 --- a/x/http/header.go +++ b/x/http/header.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/goplus/llgo/c" - "github.com/goplus/llgo/x/textproto" "github.com/goplus/llgoexamples/rust/hyper" ) @@ -18,59 +17,68 @@ type Header map[string][]string // It appends to any existing values associated with key. // The key is case insensitive; it is canonicalized by // CanonicalHeaderKey. -func (h Header) Add(key, value string) { - textproto.MIMEHeader(h).Add(key, value) -} - -// Set sets the header entries associated with key to the -// single element value. It replaces any existing values -// associated with key. The key is case insensitive; it is -// canonicalized by textproto.CanonicalMIMEHeaderKey. -// To use non-canonical keys, assign to the map directly. -func (h Header) Set(key, value string) { - textproto.MIMEHeader(h).Set(key, value) -} - -// Get gets the first value associated with the given key. If -// there are no values associated with the key, Get returns "". -// It is case insensitive; textproto.CanonicalMIMEHeaderKey is -// used to canonicalize the provided key. Get assumes that all -// keys are stored in canonical form. To use non-canonical keys, -// access the map directly. -func (h Header) Get(key string) string { - return textproto.MIMEHeader(h).Get(key) -} - -// Values returns all values associated with the given key. -// It is case insensitive; textproto.CanonicalMIMEHeaderKey is -// used to canonicalize the provided key. To use non-canonical -// keys, access the map directly. -// The returned slice is not a copy. -func (h Header) Values(key string) []string { - return textproto.MIMEHeader(h).Values(key) -} - -// get is like Get, but key must already be in CanonicalHeaderKey form. -func (h Header) get(key string) string { - if v := h[key]; len(v) > 0 { - return v[0] - } - return "" -} - -// has reports whether h has the provided key defined, even if it's -// set to 0-length slice. -func (h Header) has(key string) bool { - _, ok := h[key] - return ok -} - -// Del deletes the values associated with key. -// The key is case insensitive; it is canonicalized by -// CanonicalHeaderKey. -func (h Header) Del(key string) { - textproto.MIMEHeader(h).Del(key) -} +//func (h Header) Add(key, value string) { +// textproto.MIMEHeader(h).Add(key, value) +//} +// +//// Set sets the header entries associated with key to the +//// single element value. It replaces any existing values +//// associated with key. The key is case insensitive; it is +//// canonicalized by textproto.CanonicalMIMEHeaderKey. +//// To use non-canonical keys, assign to the map directly. +//func (h Header) Set(key, value string) { +// textproto.MIMEHeader(h).Set(key, value) +//} +// +//// Get gets the first value associated with the given key. If +//// there are no values associated with the key, Get returns "". +//// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +//// used to canonicalize the provided key. Get assumes that all +//// keys are stored in canonical form. To use non-canonical keys, +//// access the map directly. +//func (h Header) Get(key string) string { +// return textproto.MIMEHeader(h).Get(key) +//} +// +//// Values returns all values associated with the given key. +//// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +//// used to canonicalize the provided key. To use non-canonical +//// keys, access the map directly. +//// The returned slice is not a copy. +//func (h Header) Values(key string) []string { +// return textproto.MIMEHeader(h).Values(key) +//} +// +//// get is like Get, but key must already be in CanonicalHeaderKey form. +//func (h Header) get(key string) string { +// if v := h[key]; len(v) > 0 { +// return v[0] +// } +// return "" +//} +// +//// has reports whether h has the provided key defined, even if it's +//// set to 0-length slice. +//func (h Header) has(key string) bool { +// _, ok := h[key] +// return ok +//} +// +//// Del deletes the values associated with key. +//// The key is case insensitive; it is canonicalized by +//// CanonicalHeaderKey. +//func (h Header) Del(key string) { +// textproto.MIMEHeader(h).Del(key) +//} +// +//// CanonicalHeaderKey returns the canonical format of the +//// header key s. The canonicalization converts the first +//// letter and any letter following a hyphen to upper case; +//// the rest are converted to lowercase. For example, the +//// canonical key for "accept-encoding" is "Accept-Encoding". +//// If s contains a space or invalid header field bytes, it is +//// returned without modifications. +//func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } // AppendToResponseHeader (HeadersForEachCallback) prints each header to the console func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { @@ -79,10 +87,10 @@ func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, va valueStr := string((*[1 << 30]byte)(c.Pointer(value))[:valueLen:valueLen]) if resp.Header == nil { - resp.Header = make(map[string][]string) + resp.Header = make(Header) } - resp.Header.Add(nameStr, valueStr) - //resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) + //resp.Header.Add(nameStr, valueStr) + resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) return hyper.IterContinue } diff --git a/x/http/request.go b/x/http/request.go index 933ae3b..1d219a1 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -36,7 +36,8 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { Header: make(Header), timeout: 0, } - request.Header.Set("Host", request.Host) + //request.Header.Set("Host", request.Host) + request.Header["Host"] = []string{request.Host} return request, nil } diff --git a/x/http/response.go b/x/http/response.go index e69c273..a08d77e 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -1,13 +1,39 @@ package http import ( + "fmt" "io" + "strconv" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" ) type Response struct { - Status string - StatusCode int + Status string // e.g. "200 OK" + StatusCode int // e.g. 200 + Proto string // e.g. "HTTP/1.0" + ProtoMajor int // e.g. 1 + ProtoMinor int // e.g. 0 Header Header Body io.ReadCloser ContentLength int64 + Trailer Header + Chunked bool + Request *Request +} + +func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { + rp := hyperResp.ReasonPhrase() + rpLen := hyperResp.ReasonPhraseLen() + + resp.Status = strconv.Itoa(int(hyperResp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) + resp.StatusCode = int(hyperResp.Status()) + + version := int(hyperResp.Version()) + resp.ProtoMajor, resp.ProtoMinor = splitTwoDigitNumber(version) + resp.Proto = fmt.Sprintf("HTTP/%d.%d", resp.ProtoMajor, resp.ProtoMinor) + + headers := hyperResp.Headers() + headers.Foreach(AppendToResponseHeader, c.Pointer(resp)) } diff --git a/x/http/transfer.go b/x/http/transfer.go new file mode 100644 index 0000000..5157324 --- /dev/null +++ b/x/http/transfer.go @@ -0,0 +1,332 @@ +package http +// +//import ( +// "fmt" +// "io" +// "net/textproto" +// "strconv" +// "strings" +// +// "github.com/goplus/llgoexamples/rust/hyper" +//) +// +//type transferReader struct { +// // Input +// Header Header +// StatusCode int +// RequestMethod string +// ProtoMajor int +// ProtoMinor int +// // Output +// Body io.ReadCloser +// ContentLength int64 +// Chunked bool +// Close bool +// Trailer Header +//} +// +//// unsupportedTEError reports unsupported transfer-encodings. +//type unsupportedTEError struct { +// err string +//} +// +//func (uste *unsupportedTEError) Error() string { +// return uste.err +//} +// +//func readTransfer(resp *Response, hyperResp *hyper.Response) (err error) { +// //// TODO(spongehah) Replace header operations with using the textproto package +// //lengthSlice := resp.Header["content-length"] +// //if lengthSlice == nil { +// // resp.ContentLength = -1 +// //} else { +// // contentLength := resp.Header["content-length"][0] +// // length, err := strconv.Atoi(contentLength) +// // if err != nil { +// // return err +// // } +// // resp.ContentLength = int64(length) +// //} +// +// t := &transferReader{ +// Header: resp.Header, +// StatusCode: resp.StatusCode, +// RequestMethod: resp.Request.Method, +// ProtoMajor: resp.ProtoMajor, +// ProtoMinor: resp.ProtoMinor, +// } +// +// // Transfer-Encoding: chunked, and overriding Content-Length. +// if err = t.parseTransferEncoding(); err != nil { +// return err +// } +// +// realLength, err := fixLength(true, t.StatusCode, t.RequestMethod, t.Header, t.Chunked) +// if err != nil { +// return err +// } +// if t.RequestMethod == "HEAD" { +// if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil { +// return err +// } else { +// t.ContentLength = n +// } +// } else { +// t.ContentLength = realLength +// } +// +// // Trailer +// t.Trailer, err = fixTrailer(t.Header, t.Chunked) +// +// // If there is no Content-Length or chunked Transfer-Encoding on a *Response +// // and the status is not 1xx, 204 or 304, then the body is unbounded. +// // See RFC 7230, section 3.3. +// if realLength == -1 && !t.Chunked && bodyAllowedForStatus(t.StatusCode) { +// // Unbounded body. +// t.Close = true +// } +// +// return nil +//} +// +//// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. +//func (t *transferReader) parseTransferEncoding() error { +// raw, present := t.Header["Transfer-Encoding"] +// if !present { +// return nil +// } +// delete(t.Header, "Transfer-Encoding") +// +// // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. +// if !t.protoAtLeast(1, 1) { +// return nil +// } +// +// // Like nginx, we only support a single Transfer-Encoding header field, and +// // only if set to "chunked". This is one of the most security sensitive +// // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it +// // strict and simple. +// if len(raw) != 1 { +// return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} +// } +// if !equalFold(raw[0], "chunked") { +// return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} +// } +// +// // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field +// // in any message that contains a Transfer-Encoding header field." +// // +// // but also: "If a message is received with both a Transfer-Encoding and a +// // Content-Length header field, the Transfer-Encoding overrides the +// // Content-Length. Such a message might indicate an attempt to perform +// // request smuggling (Section 9.5) or response splitting (Section 9.4) and +// // ought to be handled as an error. A sender MUST remove the received +// // Content-Length field prior to forwarding such a message downstream." +// // +// // Reportedly, these appear in the wild. +// delete(t.Header, "Content-Length") +// +// t.Chunked = true +// return nil +//} +// +//func (t *transferReader) protoAtLeast(m, n int) bool { +// return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) +//} +// +//// equalFold is strings.EqualFold, ASCII only. It reports whether s and t +//// are equal, ASCII-case-insensitively. +//func equalFold(s, t string) bool { +// if len(s) != len(t) { +// return false +// } +// for i := 0; i < len(s); i++ { +// if lower(s[i]) != lower(t[i]) { +// return false +// } +// } +// return true +//} +// +//// Determine the expected body length, using RFC 7230 Section 3.3. This +//// function is not a method, because ultimately it should be shared by +//// ReadResponse and ReadRequest. +//func fixLength(isResponse bool, status int, requestMethod string, header Header, chunked bool) (int64, error) { +// isRequest := !isResponse +// contentLens := header["Content-Length"] +// +// // Hardening against HTTP request smuggling +// if len(contentLens) > 1 { +// // Per RFC 7230 Section 3.3.2, prevent multiple +// // Content-Length headers if they differ in value. +// // If there are dups of the value, remove the dups. +// // See Issue 16490. +// first := textproto.TrimString(contentLens[0]) +// for _, ct := range contentLens[1:] { +// if first != textproto.TrimString(ct) { +// return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens) +// } +// } +// +// // deduplicate Content-Length +// header.Del("Content-Length") +// header.Add("Content-Length", first) +// +// contentLens = header["Content-Length"] +// } +// +// // Logic based on response type or status +// if isResponse && noResponseBodyExpected(requestMethod) { +// return 0, nil +// } +// if status/100 == 1 { +// return 0, nil +// } +// switch status { +// case 204, 304: +// return 0, nil +// } +// +// // Logic based on Transfer-Encoding +// if chunked { +// return -1, nil +// } +// +// // Logic based on Content-Length +// var cl string +// if len(contentLens) == 1 { +// cl = textproto.TrimString(contentLens[0]) +// } +// if cl != "" { +// n, err := parseContentLength(cl) +// if err != nil { +// return -1, err +// } +// return n, nil +// } +// header.Del("Content-Length") +// +// if isRequest { +// // RFC 7230 neither explicitly permits nor forbids an +// // entity-body on a GET request so we permit one if +// // declared, but we default to 0 here (not -1 below) +// // if there's no mention of a body. +// // Likewise, all other request methods are assumed to have +// // no body if neither Transfer-Encoding chunked nor a +// // Content-Length are set. +// return 0, nil +// } +// +// // Body-EOF logic based on other methods (like closing, or chunked coding) +// return -1, nil +//} +// +//// parseContentLength trims whitespace from s and returns -1 if no value +//// is set, or the value if it's >= 0. +//func parseContentLength(cl string) (int64, error) { +// cl = textproto.TrimString(cl) +// if cl == "" { +// return -1, nil +// } +// n, err := strconv.ParseUint(cl, 10, 63) +// if err != nil { +// return 0, badStringError("bad Content-Length", cl) +// } +// return int64(n), nil +// +//} +// +//// Parse the trailer header. +//func fixTrailer(header Header, chunked bool) (Header, error) { +// vv, ok := header["Trailer"] +// if !ok { +// return nil, nil +// } +// if !chunked { +// // Trailer and no chunking: +// // this is an invalid use case for trailer header. +// // Nevertheless, no error will be returned and we +// // let users decide if this is a valid HTTP message. +// // The Trailer header will be kept in Response.Header +// // but not populate Response.Trailer. +// // See issue #27197. +// return nil, nil +// } +// header.Del("Trailer") +// +// trailer := make(Header) +// var err error +// for _, v := range vv { +// foreachHeaderElement(v, func(key string) { +// key = CanonicalHeaderKey(key) +// switch key { +// case "Transfer-Encoding", "Trailer", "Content-Length": +// if err == nil { +// err = badStringError("bad trailer key", key) +// return +// } +// } +// trailer[key] = nil +// }) +// } +// if err != nil { +// return nil, err +// } +// if len(trailer) == 0 { +// return nil, nil +// } +// return trailer, nil +//} +// +//// splitTwoDigitNumber splits a two-digit number into two digits. +func splitTwoDigitNumber(num int) (int, int) { + tens := num / 10 + ones := num % 10 + return tens, ones +} +// +//// lower returns the ASCII lowercase version of b. +//func lower(b byte) byte { +// if 'A' <= b && b <= 'Z' { +// return b + ('a' - 'A') +// } +// return b +//} +// +//// foreachHeaderElement splits v according to the "#rule" construction +//// in RFC 7230 section 7 and calls fn for each non-empty element. +//func foreachHeaderElement(v string, fn func(string)) { +// v = textproto.TrimString(v) +// if v == "" { +// return +// } +// if !strings.Contains(v, ",") { +// fn(v) +// return +// } +// for _, f := range strings.Split(v, ",") { +// if f = textproto.TrimString(f); f != "" { +// fn(f) +// } +// } +//} +// +//func noResponseBodyExpected(requestMethod string) bool { +// return requestMethod == "HEAD" +//} +// +//func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } +// +//// bodyAllowedForStatus reports whether a given response status code +//// permits a body. See RFC 7230, section 3.3. +//func bodyAllowedForStatus(status int) bool { +// switch { +// case status >= 100 && status <= 199: +// return false +// case status == 204: +// return false +// case status == 304: +// return false +// } +// return true +//} diff --git a/x/http/transport.go b/x/http/transport.go index ec5ae02..b9f845c 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -3,7 +3,6 @@ package http import ( "fmt" "io" - "strconv" "unsafe" "github.com/goplus/llgo/c" @@ -229,7 +228,11 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Poll all ready tasks and act on them... rc := <-pc.reqch // blocking alive := true - var response Response + resp := &Response{ + Request: rc.req, + Header: make(Header), + Trailer: make(Header), + } var bodyWriter *io.PipeWriter var respBody *hyper.Body = nil for alive { @@ -283,52 +286,22 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } // Take the results - resp := (*hyper.Response)(task.Value()) + hyperResp := (*hyper.Response)(task.Value()) task.Free() - rp := resp.ReasonPhrase() - rpLen := resp.ReasonPhraseLen() - - response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) - response.StatusCode = int(resp.Status()) - - headers := resp.Headers() - headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) - respBody = resp.Body() - - response.Body, bodyWriter = io.Pipe() - - //// TODO(spongehah) Replace header operations with using the textproto package - //lengthSlice := response.Header["content-length"] - //if lengthSlice == nil { - // response.ContentLength = -1 - //} else { - // contentLength := response.Header["content-length"][0] - // length, err := strconv.Atoi(contentLength) - // if err != nil { - // rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} - // // Free the resources - // FreeResources(task, respBody, bodyWriter, exec, pc, rc) - // return - // } - // response.ContentLength = int64(length) + readResponseLineAndHeader(resp, hyperResp) + //err = readTransfer(resp, hyperResp) + //if err != nil { + // rc.ch <- responseAndError{err: err} + // // Free the resources + // FreeResources(task, respBody, bodyWriter, exec, pc, rc) + // return //} - contentLength := response.Header.Get("content-length") - if contentLength == "" { - response.ContentLength = -1 - } else { - length, err := strconv.Atoi(contentLength) - if err != nil { - rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} - // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) - return - } - response.ContentLength = int64(length) - } + respBody = hyperResp.Body() + resp.Body, bodyWriter = io.Pipe() - rc.ch <- responseAndError{res: &response} + rc.ch <- responseAndError{res: resp} // Response has been returned, stop the timer pc.conn.IsCompleted = 1 @@ -343,7 +316,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { exec.Push(dataTask) // No longer need the response - resp.Free() + hyperResp.Free() case ReceiveRespBody: err := CheckTaskType(task, ReceiveRespBody) if err != nil { From 2e9e3384d9913d506f5108059c509dc978609c89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Thu, 15 Aug 2024 18:16:45 +0800 Subject: [PATCH 09/21] WIP(x/http/client): Code tweaks made; Post request runs successfully. Initial redirection check implemented. --- x/http/_demo/get/get.go | 1 + x/http/_demo/headers/headers.go | 10 +- x/http/_demo/post/post.go | 24 + x/http/_demo/upload/example.txt | 1 + x/http/_demo/upload/upload.go | 24 + x/http/client.go | 179 ++++++- x/http/header.go | 132 +++--- x/http/request.go | 164 +++++-- x/http/response.go | 58 ++- x/http/transfer.go | 808 +++++++++++++++++++------------- x/http/transport.go | 33 +- 11 files changed, 974 insertions(+), 460 deletions(-) create mode 100644 x/http/_demo/post/post.go create mode 100755 x/http/_demo/upload/example.txt create mode 100644 x/http/_demo/upload/upload.go diff --git a/x/http/_demo/get/get.go b/x/http/_demo/get/get.go index c8460e4..bff1bd1 100644 --- a/x/http/_demo/get/get.go +++ b/x/http/_demo/get/get.go @@ -14,6 +14,7 @@ func main() { return } fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) + fmt.Println(resp.Proto) resp.PrintHeaders() body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/x/http/_demo/headers/headers.go b/x/http/_demo/headers/headers.go index 98cda79..2672a66 100644 --- a/x/http/_demo/headers/headers.go +++ b/x/http/_demo/headers/headers.go @@ -4,19 +4,19 @@ import ( "fmt" "io" - "github.com/goplus/llgo/x/http" + "github.com/goplus/llgoexamples/x/http" ) func main() { client := &http.Client{} - req, err := http.NewRequest("GET", "https://jsonplaceholder.typicode.com/comments?postId=1", nil) + req, err := http.NewRequest("GET", "https://www.baidu.com", nil) if err != nil { println(err.Error()) return } //req.Header.Set("accept", "*/*") - //req.Header.Set("accept-encoding", "identity") + req.Header.Set("accept-encoding", "identity") //req.Header.Set("cache-control", "no-cache") //req.Header.Set("pragma", "no-cache") //req.Header.Set("priority", "u=0, i") @@ -28,7 +28,7 @@ func main() { //req.Header.Set("sec-fetch-mode", "navigate") //req.Header.Set("sec-fetch-site", "same-origin") //req.Header.Set("sec-fetch-user", "?1") - //req.Header.Set("upgrade-insecure-requests", "1") + ////req.Header.Set("upgrade-insecure-requests", "1") //req.Header.Set("user-agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36") resp, err := client.Do(req) @@ -36,10 +36,12 @@ func main() { println(err.Error()) return } + resp.PrintHeaders() body, err := io.ReadAll(resp.Body) if err != nil { println(err.Error()) return } fmt.Println(string(body)) + defer resp.Body.Close() } diff --git a/x/http/_demo/post/post.go b/x/http/_demo/post/post.go new file mode 100644 index 0000000..a86e805 --- /dev/null +++ b/x/http/_demo/post/post.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgoexamples/x/http" +) + +func main() { + resp, err := http.Post("https://jsonplaceholder.typicode.com/posts", "application/json; charset=UTF-8", nil) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status) + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + defer resp.Body.Close() +} diff --git a/x/http/_demo/upload/example.txt b/x/http/_demo/upload/example.txt new file mode 100755 index 0000000..1253cd4 --- /dev/null +++ b/x/http/_demo/upload/example.txt @@ -0,0 +1 @@ +hello upload \ No newline at end of file diff --git a/x/http/_demo/upload/upload.go b/x/http/_demo/upload/upload.go new file mode 100644 index 0000000..c6bb391 --- /dev/null +++ b/x/http/_demo/upload/upload.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgoexamples/x/http" +) + +func main() { + resp, err := http.Post("http://httpbin.org/post", "", nil) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status) + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + defer resp.Body.Close() +} diff --git a/x/http/client.go b/x/http/client.go index 177b089..72e9688 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -1,6 +1,12 @@ package http -import "time" +import ( + "errors" + "io" + "net/url" + "strings" + "time" +) type Client struct { Transport RoundTripper @@ -32,17 +38,93 @@ func (c *Client) Get(url string) (*Response, error) { return c.Do(req) } -func (c *Client) Do(req *Request) (*Response, error) { - return c.do(req) +func Post(url, contentType string, body io.Reader) (resp *Response, err error) { + return DefaultClient.Post(url, contentType, body) } -func (c *Client) do(req *Request) (*Response, error) { - // Add user-defined request headers to hyper.Request - err := req.setHeaders() +func (c *Client) Post(url, contentType string, body io.Reader) (resp *Response, err error) { + req, err := NewRequest("POST", url, body) if err != nil { return nil, err } - return c.send(req, c.Timeout) + req.Header.Set("Content-Type", contentType) + return c.Do(req) +} + +func (c *Client) Do(req *Request) (*Response, error) { + return c.do(req) +} + +var testHookClientDoResult func(retres *Response, reterr error) + +func (c *Client) do(req *Request) (retres *Response, reterr error) { + if testHookClientDoResult != nil { + defer func() { testHookClientDoResult(retres, reterr) }() + } + + if req.URL == nil { + req.closeBody() + return nil, &url.Error{ + Op: urlErrorOp(req.Method), + Err: errors.New("http: nil Request.URL"), + } + } + var ( + //deadline = c.deadline() + reqs []*Request + resp *Response + //copyHeaders = c.makeHeadersCopier(req) + reqBodyClosed = false // have we closed the current req.Body? + + // Redirect behavior: + //redirectMethod string + //includeBody bool + ) + uerr := func(err error) error { + // the body may have been closed already by c.send() + if !reqBodyClosed { + req.closeBody() + } + var urlStr string + if resp != nil && resp.Request != nil { + urlStr = stripPassword(resp.Request.URL) + } else { + urlStr = stripPassword(req.URL) + } + return &url.Error{ + Op: urlErrorOp(reqs[0].Method), + URL: urlStr, + Err: err, + } + } + + // For all but the first request, create the next + // request hop and replace req. + for { + if len(reqs) > 0 { + + } + + reqs = append(reqs, req) + var err error + if resp, err = c.send(req, c.Timeout); err != nil { + // c.send() always closes req.Body + reqBodyClosed = true + return nil, uerr(err) + } + + var shouldRedirect bool + //redirectMethod, shouldRedirect, includeBody = redirectBehavior(req.Method, resp, reqs[0]) + _, shouldRedirect, _ = redirectBehavior(req.Method, resp, reqs[0]) + if !shouldRedirect { + return resp, nil + } else { + // TODO(spongehah) + return nil, errors.New("TODO: redirect not implemented") + } + + req.closeBody() + } } func (c *Client) send(req *Request, timeout time.Duration) (*Response, error) { @@ -53,3 +135,86 @@ func send(req *Request, rt RoundTripper, timeout time.Duration) (resp *Response, req.timeout = timeout return rt.RoundTrip(req) } + +// redirectBehavior describes what should happen when the +// client encounters a 3xx status code from the server. +func redirectBehavior(reqMethod string, resp *Response, ireq *Request) (redirectMethod string, shouldRedirect, includeBody bool) { + switch resp.StatusCode { + case 301, 302, 303: + redirectMethod = reqMethod + shouldRedirect = true + includeBody = false + + // RFC 2616 allowed automatic redirection only with GET and + // HEAD requests. RFC 7231 lifts this restriction, but we still + // restrict other methods to GET to maintain compatibility. + // See Issue 18570. + if reqMethod != "GET" && reqMethod != "HEAD" { + redirectMethod = "GET" + } + case 307, 308: + redirectMethod = reqMethod + shouldRedirect = true + includeBody = true + + if ireq.GetBody == nil && ireq.outgoingLength() != 0 { + // We had a request body, and 307/308 require + // re-sending it, but GetBody is not defined. So just + // return this response to the user instead of an + // error, like we did in Go 1.7 and earlier. + shouldRedirect = false + } + } + return redirectMethod, shouldRedirect, includeBody +} + +// outgoingLength reports the Content-Length of this outgoing (Client) request. +// It maps 0 into -1 (unknown) when the Body is non-nil. +func (r *Request) outgoingLength() int64 { + if r.Body == nil || r.Body == NoBody { + return 0 + } + if r.ContentLength != 0 { + return r.ContentLength + } + return -1 +} + +// urlErrorOp returns the (*url.Error).Op value to use for the +// provided (*Request).Method value. +func urlErrorOp(method string) string { + if method == "" { + return "Get" + } + if lowerMethod, ok := ToLower(method); ok { + return method[:1] + lowerMethod[1:] + } + return method +} + +// ToLower returns the lowercase version of s if s is ASCII and printable. +func ToLower(s string) (lower string, ok bool) { + if !IsPrint(s) { + return "", false + } + return strings.ToLower(s), true +} + +// IsPrint returns whether s is ASCII and printable according to +// https://tools.ietf.org/html/rfc20#section-4.2. +func IsPrint(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] < ' ' || s[i] > '~' { + return false + } + } + return true +} + +func stripPassword(u *url.URL) string { + _, passSet := u.User.Password() + if passSet { + return strings.Replace(u.String(), u.User.String()+"@", u.User.Username()+":***@", 1) + } + return u.String() +} diff --git a/x/http/header.go b/x/http/header.go index 0533ed7..076db0f 100644 --- a/x/http/header.go +++ b/x/http/header.go @@ -2,6 +2,7 @@ package http import ( "fmt" + "net/textproto" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" @@ -17,80 +18,79 @@ type Header map[string][]string // It appends to any existing values associated with key. // The key is case insensitive; it is canonicalized by // CanonicalHeaderKey. -//func (h Header) Add(key, value string) { -// textproto.MIMEHeader(h).Add(key, value) -//} -// -//// Set sets the header entries associated with key to the -//// single element value. It replaces any existing values -//// associated with key. The key is case insensitive; it is -//// canonicalized by textproto.CanonicalMIMEHeaderKey. -//// To use non-canonical keys, assign to the map directly. -//func (h Header) Set(key, value string) { -// textproto.MIMEHeader(h).Set(key, value) -//} -// -//// Get gets the first value associated with the given key. If -//// there are no values associated with the key, Get returns "". -//// It is case insensitive; textproto.CanonicalMIMEHeaderKey is -//// used to canonicalize the provided key. Get assumes that all -//// keys are stored in canonical form. To use non-canonical keys, -//// access the map directly. -//func (h Header) Get(key string) string { -// return textproto.MIMEHeader(h).Get(key) -//} -// -//// Values returns all values associated with the given key. -//// It is case insensitive; textproto.CanonicalMIMEHeaderKey is -//// used to canonicalize the provided key. To use non-canonical -//// keys, access the map directly. -//// The returned slice is not a copy. -//func (h Header) Values(key string) []string { -// return textproto.MIMEHeader(h).Values(key) -//} -// -//// get is like Get, but key must already be in CanonicalHeaderKey form. -//func (h Header) get(key string) string { -// if v := h[key]; len(v) > 0 { -// return v[0] -// } -// return "" -//} -// -//// has reports whether h has the provided key defined, even if it's -//// set to 0-length slice. -//func (h Header) has(key string) bool { -// _, ok := h[key] -// return ok -//} -// -//// Del deletes the values associated with key. -//// The key is case insensitive; it is canonicalized by -//// CanonicalHeaderKey. -//func (h Header) Del(key string) { -// textproto.MIMEHeader(h).Del(key) -//} -// -//// CanonicalHeaderKey returns the canonical format of the -//// header key s. The canonicalization converts the first -//// letter and any letter following a hyphen to upper case; -//// the rest are converted to lowercase. For example, the -//// canonical key for "accept-encoding" is "Accept-Encoding". -//// If s contains a space or invalid header field bytes, it is -//// returned without modifications. -//func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } +func (h Header) Add(key, value string) { + textproto.MIMEHeader(h).Add(key, value) +} + +// Set sets the header entries associated with key to the +// single element value. It replaces any existing values +// associated with key. The key is case insensitive; it is +// canonicalized by textproto.CanonicalMIMEHeaderKey. +// To use non-canonical keys, assign to the map directly. +func (h Header) Set(key, value string) { + textproto.MIMEHeader(h).Set(key, value) +} + +// Get gets the first value associated with the given key. If +// there are no values associated with the key, Get returns "". +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. Get assumes that all +// keys are stored in canonical form. To use non-canonical keys, +// access the map directly. +func (h Header) Get(key string) string { + return textproto.MIMEHeader(h).Get(key) +} + +// Values returns all values associated with the given key. +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. To use non-canonical +// keys, access the map directly. +// The returned slice is not a copy. +func (h Header) Values(key string) []string { + return textproto.MIMEHeader(h).Values(key) +} + +// get is like Get, but key must already be in CanonicalHeaderKey form. +func (h Header) get(key string) string { + if v := h[key]; len(v) > 0 { + return v[0] + } + return "" +} + +// has reports whether h has the provided key defined, even if it's +// set to 0-length slice. +func (h Header) has(key string) bool { + _, ok := h[key] + return ok +} + +// Del deletes the values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (h Header) Del(key string) { + textproto.MIMEHeader(h).Del(key) +} + +// CanonicalHeaderKey returns the canonical format of the +// header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +// If s contains a space or invalid header field bytes, it is +// returned without modifications. +func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } // AppendToResponseHeader (HeadersForEachCallback) prints each header to the console func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { resp := (*Response)(userdata) - nameStr := string((*[1 << 30]byte)(c.Pointer(name))[:nameLen:nameLen]) - valueStr := string((*[1 << 30]byte)(c.Pointer(value))[:valueLen:valueLen]) + nameStr := c.GoString((*int8)(c.Pointer(name)), nameLen) + valueStr := c.GoString((*int8)(c.Pointer(value)), valueLen) if resp.Header == nil { resp.Header = make(Header) } - //resp.Header.Add(nameStr, valueStr) - resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) + resp.Header.Add(nameStr, valueStr) return hyper.IterContinue } diff --git a/x/http/request.go b/x/http/request.go index 1d219a1..c84cd9b 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -7,74 +7,169 @@ import ( "time" "github.com/goplus/llgo/c" + "github.com/goplus/llgo/c/os" "github.com/goplus/llgoexamples/rust/hyper" ) type Request struct { - Method string - URL *url.URL - Req *hyper.Request - Host string - Header Header - timeout time.Duration + Method string + URL *url.URL + Proto string // "HTTP/1.0" + ProtoMajor int // 1 + ProtoMinor int // 0 + Header Header + Body io.ReadCloser + GetBody func() (io.ReadCloser, error) + ContentLength int64 + TransferEncoding []string + Close bool + Host string + timeout time.Duration } +type postBody struct { + data []byte + len uintptr + readLen uintptr +} + +type uploadBody struct { + fd c.Int + buf []byte + len uintptr +} + +var DefaultChunkSize uintptr = 8192 + func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { - parseURL, err := url.Parse(urlStr) - if err != nil { - return nil, err - } - req, err := newHyperRequest(method, parseURL) + u, err := url.Parse(urlStr) if err != nil { return nil, err } + //rc, ok := body.(io.ReadCloser) + //if !ok && body != nil { + // rc = io.NopCloser(body) + //} request := &Request{ - Method: method, - URL: parseURL, - Req: req, - Host: parseURL.Hostname(), - Header: make(Header), + Method: method, + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + Host: u.Host, + //Body: rc, timeout: 0, } - //request.Header.Set("Host", request.Host) - request.Header["Host"] = []string{request.Host} + request.Header.Set("Host", request.Host) + return request, nil } -func newHyperRequest(method string, URL *url.URL) (*hyper.Request, error) { - host := URL.Hostname() - uri := URL.RequestURI() +func PrintInformational(userdata c.Pointer, resp *hyper.Response) { + status := resp.Status() + fmt.Println("Informational (1xx): ", status) +} + +func SetPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { + //upload := (*uploadBody)(userdata) + //res := os.Read(upload.fd, c.Pointer(&upload.buf[0]), upload.len) + //if res > 0 { + // *chunk = hyper.CopyBuf(&upload.buf[0], uintptr(res)) + // return hyper.PollReady + //} + //if res == 0 { + // *chunk = nil + // os.Close(upload.fd) + // return hyper.PollReady + //} + body := (*postBody)(userdata) + if body.len > 0 { + if body.len > DefaultChunkSize { + *chunk = hyper.CopyBuf(&body.data[body.readLen], DefaultChunkSize) + body.readLen += DefaultChunkSize + body.len -= DefaultChunkSize + } else { + *chunk = hyper.CopyBuf(&body.data[body.readLen], body.len) + body.readLen += body.len + body.len = 0 + } + return hyper.PollReady + } + if body.len == 0 { + *chunk = nil + return hyper.PollReady + } + + fmt.Printf("error reading upload file: %s\n", c.GoString(c.Strerror(os.Errno))) + return hyper.PollError +} + +func newHyperRequest(req *Request) (*hyper.Request, error) { + host := req.Host + uri := req.URL.Path + method := req.Method // Prepare the request - req := hyper.NewRequest() + hyperReq := hyper.NewRequest() // Set the request method and uri - if req.SetMethod((*uint8)(&[]byte(method)[0]), c.Strlen(c.AllocaCStr(method))) != hyper.OK { + if hyperReq.SetMethod(&[]byte(method)[0], c.Strlen(c.AllocaCStr(method))) != hyper.OK { return nil, fmt.Errorf("error setting method %s\n", method) } - if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { + if hyperReq.SetURI(&[]byte(uri)[0], c.Strlen(c.AllocaCStr(uri))) != hyper.OK { return nil, fmt.Errorf("error setting uri %s\n", uri) } // Set the request headers - reqHeaders := req.Headers() - if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { + reqHeaders := hyperReq.Headers() + if reqHeaders.Set(&[]byte("Host")[0], c.Strlen(c.Str("Host")), &[]byte(host)[0], c.Strlen(c.AllocaCStr(host))) != hyper.OK { return nil, fmt.Errorf("error setting header: Host: %s\n", host) } - return req, nil + if method == "POST" { + //var upload uploadBody + //upload.fd = os.Open(c.Str("/Users/spongehah/go/src/llgo/x/http/_demo/post/example.txt"), os.O_RDONLY) + //if upload.fd < 0 { + // return nil, fmt.Errorf("error opening file to upload: %s\n", c.GoString(c.Strerror(os.Errno))) + //} + //upload.len = 8192 + //upload.buf = make([]byte, upload.len) + req.Header.Set("expect", "100-continue") + hyperReq.OnInformational(PrintInformational, nil) + postData := []byte(`{"id":1,"title":"foo","body":"bar","userId":"1"}`) + + reqBody := &postBody{ + data: postData, + len: uintptr(len(postData)), + } + + hyperReqBody := hyper.NewBody() + hyperReqBody.SetUserdata(c.Pointer(reqBody)) + //hyperReqBody.SetUserdata(c.Pointer(&upload)) + hyperReqBody.SetDataFunc(SetPostData) + hyperReq.SetBody(hyperReqBody) + } + + // Add user-defined request headers to hyper.Request + err := req.setHeaders(hyperReq) + if err != nil { + return nil, err + } + + return hyperReq, nil } // setHeaders sets the headers of the request -func (req *Request) setHeaders() error { - headers := req.Req.Headers() +func (req *Request) setHeaders(hyperReq *hyper.Request) error { + headers := hyperReq.Headers() for key, values := range req.Header { valueLen := len(values) if valueLen > 1 { for _, value := range values { - if headers.Add((*uint8)(&[]byte(key)[0]), c.Strlen(c.AllocaCStr(key)), (*uint8)(&[]byte(value)[0]), c.Strlen(c.AllocaCStr(value))) != hyper.OK { + if headers.Add(&[]byte(key)[0], c.Strlen(c.AllocaCStr(key)), &[]byte(value)[0], c.Strlen(c.AllocaCStr(value))) != hyper.OK { return fmt.Errorf("error adding header %s: %s\n", key, value) } } } else if valueLen == 1 { - if headers.Set((*uint8)(&[]byte(key)[0]), c.Strlen(c.AllocaCStr(key)), (*uint8)(&[]byte(values[0])[0]), c.Strlen(c.AllocaCStr(values[0]))) != hyper.OK { + if headers.Set(&[]byte(key)[0], c.Strlen(c.AllocaCStr(key)), &[]byte(values[0])[0], c.Strlen(c.AllocaCStr(values[0]))) != hyper.OK { return fmt.Errorf("error setting header %s: %s\n", key, values[0]) } } else { @@ -83,3 +178,10 @@ func (req *Request) setHeaders() error { } return nil } + +func (r *Request) closeBody() error { + if r.Body == nil { + return nil + } + return r.Body.Close() +} diff --git a/x/http/response.go b/x/http/response.go index a08d77e..c99bade 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -10,24 +10,43 @@ import ( ) type Response struct { - Status string // e.g. "200 OK" - StatusCode int // e.g. 200 - Proto string // e.g. "HTTP/1.0" - ProtoMajor int // e.g. 1 - ProtoMinor int // e.g. 0 - Header Header - Body io.ReadCloser - ContentLength int64 - Trailer Header - Chunked bool - Request *Request + Status string // e.g. "200 OK" + StatusCode int // e.g. 200 + Proto string // e.g. "HTTP/1.0" + ProtoMajor int // e.g. 1 + ProtoMinor int // e.g. 0 + Header Header + Body io.ReadCloser + ContentLength int64 + TransferEncoding []string + Close bool + Trailer Header + Request *Request } +func ReadResponse(hyperResp *hyper.Response, req *Request) (*Response, error) { + resp := &Response{ + Request: req, + Header: make(Header), + Trailer: make(Header), + } + readResponseLineAndHeader(resp, hyperResp) + + fixPragmaCacheControl(req.Header) + + err := readTransfer(resp) + if err != nil { + return nil, err + } + return resp, nil +} + +// readResponseLineAndHeader reads the response line and header from hyper response. func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { rp := hyperResp.ReasonPhrase() rpLen := hyperResp.ReasonPhraseLen() - resp.Status = strconv.Itoa(int(hyperResp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) + resp.Status = strconv.Itoa(int(hyperResp.Status())) + " " + c.GoString((*int8)(c.Pointer(rp)), rpLen) resp.StatusCode = int(hyperResp.Status()) version := int(hyperResp.Version()) @@ -37,3 +56,18 @@ func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { headers := hyperResp.Headers() headers.Foreach(AppendToResponseHeader, c.Pointer(resp)) } + +// RFC 7234, section 5.4: Should treat +// +// Pragma: no-cache +// +// like +// +// Cache-Control: no-cache +func fixPragmaCacheControl(header Header) { + if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { + if _, presentcc := header["Cache-Control"]; !presentcc { + header["Cache-Control"] = []string{"no-cache"} + } + } +} diff --git a/x/http/transfer.go b/x/http/transfer.go index 5157324..70f082e 100644 --- a/x/http/transfer.go +++ b/x/http/transfer.go @@ -1,332 +1,488 @@ package http -// -//import ( -// "fmt" -// "io" -// "net/textproto" -// "strconv" -// "strings" -// -// "github.com/goplus/llgoexamples/rust/hyper" -//) -// -//type transferReader struct { -// // Input -// Header Header -// StatusCode int -// RequestMethod string -// ProtoMajor int -// ProtoMinor int -// // Output -// Body io.ReadCloser -// ContentLength int64 -// Chunked bool -// Close bool -// Trailer Header -//} -// -//// unsupportedTEError reports unsupported transfer-encodings. -//type unsupportedTEError struct { -// err string -//} -// -//func (uste *unsupportedTEError) Error() string { -// return uste.err -//} -// -//func readTransfer(resp *Response, hyperResp *hyper.Response) (err error) { -// //// TODO(spongehah) Replace header operations with using the textproto package -// //lengthSlice := resp.Header["content-length"] -// //if lengthSlice == nil { -// // resp.ContentLength = -1 -// //} else { -// // contentLength := resp.Header["content-length"][0] -// // length, err := strconv.Atoi(contentLength) -// // if err != nil { -// // return err -// // } -// // resp.ContentLength = int64(length) -// //} -// -// t := &transferReader{ -// Header: resp.Header, -// StatusCode: resp.StatusCode, -// RequestMethod: resp.Request.Method, -// ProtoMajor: resp.ProtoMajor, -// ProtoMinor: resp.ProtoMinor, -// } -// -// // Transfer-Encoding: chunked, and overriding Content-Length. -// if err = t.parseTransferEncoding(); err != nil { -// return err -// } -// -// realLength, err := fixLength(true, t.StatusCode, t.RequestMethod, t.Header, t.Chunked) -// if err != nil { -// return err -// } -// if t.RequestMethod == "HEAD" { -// if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil { -// return err -// } else { -// t.ContentLength = n -// } -// } else { -// t.ContentLength = realLength -// } -// -// // Trailer -// t.Trailer, err = fixTrailer(t.Header, t.Chunked) -// -// // If there is no Content-Length or chunked Transfer-Encoding on a *Response -// // and the status is not 1xx, 204 or 304, then the body is unbounded. -// // See RFC 7230, section 3.3. -// if realLength == -1 && !t.Chunked && bodyAllowedForStatus(t.StatusCode) { -// // Unbounded body. -// t.Close = true -// } -// -// return nil -//} -// -//// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. -//func (t *transferReader) parseTransferEncoding() error { -// raw, present := t.Header["Transfer-Encoding"] -// if !present { -// return nil -// } -// delete(t.Header, "Transfer-Encoding") -// -// // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. -// if !t.protoAtLeast(1, 1) { -// return nil -// } -// -// // Like nginx, we only support a single Transfer-Encoding header field, and -// // only if set to "chunked". This is one of the most security sensitive -// // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it -// // strict and simple. -// if len(raw) != 1 { -// return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} -// } -// if !equalFold(raw[0], "chunked") { -// return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} -// } -// -// // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field -// // in any message that contains a Transfer-Encoding header field." -// // -// // but also: "If a message is received with both a Transfer-Encoding and a -// // Content-Length header field, the Transfer-Encoding overrides the -// // Content-Length. Such a message might indicate an attempt to perform -// // request smuggling (Section 9.5) or response splitting (Section 9.4) and -// // ought to be handled as an error. A sender MUST remove the received -// // Content-Length field prior to forwarding such a message downstream." -// // -// // Reportedly, these appear in the wild. -// delete(t.Header, "Content-Length") -// -// t.Chunked = true -// return nil -//} -// -//func (t *transferReader) protoAtLeast(m, n int) bool { -// return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) -//} -// -//// equalFold is strings.EqualFold, ASCII only. It reports whether s and t -//// are equal, ASCII-case-insensitively. -//func equalFold(s, t string) bool { -// if len(s) != len(t) { -// return false -// } -// for i := 0; i < len(s); i++ { -// if lower(s[i]) != lower(t[i]) { -// return false -// } -// } -// return true -//} -// -//// Determine the expected body length, using RFC 7230 Section 3.3. This -//// function is not a method, because ultimately it should be shared by -//// ReadResponse and ReadRequest. -//func fixLength(isResponse bool, status int, requestMethod string, header Header, chunked bool) (int64, error) { -// isRequest := !isResponse -// contentLens := header["Content-Length"] -// -// // Hardening against HTTP request smuggling -// if len(contentLens) > 1 { -// // Per RFC 7230 Section 3.3.2, prevent multiple -// // Content-Length headers if they differ in value. -// // If there are dups of the value, remove the dups. -// // See Issue 16490. -// first := textproto.TrimString(contentLens[0]) -// for _, ct := range contentLens[1:] { -// if first != textproto.TrimString(ct) { -// return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens) -// } -// } -// -// // deduplicate Content-Length -// header.Del("Content-Length") -// header.Add("Content-Length", first) -// -// contentLens = header["Content-Length"] -// } -// -// // Logic based on response type or status -// if isResponse && noResponseBodyExpected(requestMethod) { -// return 0, nil -// } -// if status/100 == 1 { -// return 0, nil -// } -// switch status { -// case 204, 304: -// return 0, nil -// } -// -// // Logic based on Transfer-Encoding -// if chunked { -// return -1, nil -// } -// -// // Logic based on Content-Length -// var cl string -// if len(contentLens) == 1 { -// cl = textproto.TrimString(contentLens[0]) -// } -// if cl != "" { -// n, err := parseContentLength(cl) -// if err != nil { -// return -1, err -// } -// return n, nil -// } -// header.Del("Content-Length") -// -// if isRequest { -// // RFC 7230 neither explicitly permits nor forbids an -// // entity-body on a GET request so we permit one if -// // declared, but we default to 0 here (not -1 below) -// // if there's no mention of a body. -// // Likewise, all other request methods are assumed to have -// // no body if neither Transfer-Encoding chunked nor a -// // Content-Length are set. -// return 0, nil -// } -// -// // Body-EOF logic based on other methods (like closing, or chunked coding) -// return -1, nil -//} -// -//// parseContentLength trims whitespace from s and returns -1 if no value -//// is set, or the value if it's >= 0. -//func parseContentLength(cl string) (int64, error) { -// cl = textproto.TrimString(cl) -// if cl == "" { -// return -1, nil -// } -// n, err := strconv.ParseUint(cl, 10, 63) -// if err != nil { -// return 0, badStringError("bad Content-Length", cl) -// } -// return int64(n), nil -// -//} -// -//// Parse the trailer header. -//func fixTrailer(header Header, chunked bool) (Header, error) { -// vv, ok := header["Trailer"] -// if !ok { -// return nil, nil -// } -// if !chunked { -// // Trailer and no chunking: -// // this is an invalid use case for trailer header. -// // Nevertheless, no error will be returned and we -// // let users decide if this is a valid HTTP message. -// // The Trailer header will be kept in Response.Header -// // but not populate Response.Trailer. -// // See issue #27197. -// return nil, nil -// } -// header.Del("Trailer") -// -// trailer := make(Header) -// var err error -// for _, v := range vv { -// foreachHeaderElement(v, func(key string) { -// key = CanonicalHeaderKey(key) -// switch key { -// case "Transfer-Encoding", "Trailer", "Content-Length": -// if err == nil { -// err = badStringError("bad trailer key", key) -// return -// } -// } -// trailer[key] = nil -// }) -// } -// if err != nil { -// return nil, err -// } -// if len(trailer) == 0 { -// return nil, nil -// } -// return trailer, nil -//} -// -//// splitTwoDigitNumber splits a two-digit number into two digits. + +import ( + "fmt" + "io" + "net/textproto" + "strconv" + "strings" + "unicode/utf8" +) + +type transferReader struct { + // Input + Header Header + StatusCode int + RequestMethod string + ProtoMajor int + ProtoMinor int + // Output + Body io.ReadCloser + ContentLength int64 + Chunked bool + Close bool + Trailer Header +} + +// unsupportedTEError reports unsupported transfer-encodings. +type unsupportedTEError struct { + err string +} + +func (uste *unsupportedTEError) Error() string { + return uste.err +} + +// NoBody is an io.ReadCloser with no bytes. Read always returns EOF +// and Close always returns nil. It can be used in an outgoing client +// request to explicitly signal that a request has zero bytes. +// An alternative, however, is to simply set Request.Body to nil. +var NoBody = noBody{} + +type noBody struct{} + +func (noBody) Read([]byte) (int, error) { return 0, io.EOF } +func (noBody) Close() error { return nil } +func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil } + +func readTransfer(msg any) (err error) { + t := &transferReader{RequestMethod: "GET"} + + // Unify input + isResponse := false + switch rr := msg.(type) { + case *Response: + t.Header = rr.Header + t.StatusCode = rr.StatusCode + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header, true) + isResponse = true + if rr.Request != nil { + t.RequestMethod = rr.Request.Method + } + case *Request: + t.Header = rr.Header + t.RequestMethod = rr.Method + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + // Transfer semantics for Requests are exactly like those for + // Responses with status code 200, responding to a GET method + t.StatusCode = 200 + t.Close = rr.Close + default: + panic("unexpected type") + } + + // Transfer-Encoding: chunked, and overriding Content-Length. + if err = t.parseTransferEncoding(); err != nil { + return err + } + + realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.Chunked) + if err != nil { + return err + } + if isResponse && t.RequestMethod == "HEAD" { + if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil { + return err + } else { + t.ContentLength = n + } + } else { + t.ContentLength = realLength + } + + // Trailer + t.Trailer, err = fixTrailer(t.Header, t.Chunked) + + // If there is no Content-Length or chunked Transfer-Encoding on a *Response + // and the status is not 1xx, 204 or 304, then the body is unbounded. + // See RFC 7230, section 3.3. + switch msg.(type) { + case *Response: + if realLength == -1 && !t.Chunked && bodyAllowedForStatus(t.StatusCode) { + // Unbounded body. + t.Close = true + } + } + + // Prepare body reader. ContentLength < 0 means chunked encoding + // or close connection when finished, since multipart is not supported yet + //switch { + //case t.Chunked: + // if isResponse && noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { + // t.Body = NoBody + // } else { + // t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} + // } + //case realLength == 0: + // t.Body = NoBody + //case realLength > 0: + // t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} + //default: + // // realLength < 0, i.e. "Content-Length" not mentioned in header + // if t.Close { + // // Close semantics (i.e. HTTP/1.0) + // t.Body = &body{src: r, closing: t.Close} + // } else { + // // Persistent connection (i.e. HTTP/1.1) + // t.Body = NoBody + // } + //} + + // Unify output + switch rr := msg.(type) { + case *Request: + //rr.Body = t.Body + //rr.ContentLength = t.ContentLength + //if t.Chunked { + // rr.TransferEncoding = []string{"chunked"} + //} + rr.Close = t.Close + //rr.Trailer = t.Trailer + case *Response: + //rr.Body = t.Body + rr.ContentLength = t.ContentLength + if t.Chunked { + rr.TransferEncoding = []string{"chunked"} + } + rr.Close = t.Close + rr.Trailer = t.Trailer + } + + return nil +} + +// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. +func (t *transferReader) parseTransferEncoding() error { + raw, present := t.Header["Transfer-Encoding"] + if !present { + return nil + } + delete(t.Header, "Transfer-Encoding") + + // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. + if !t.protoAtLeast(1, 1) { + return nil + } + + // Like nginx, we only support a single Transfer-Encoding header field, and + // only if set to "chunked". This is one of the most security sensitive + // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it + // strict and simple. + if len(raw) != 1 { + return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} + } + if !equalFold(raw[0], "chunked") { + return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} + } + + // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field + // in any message that contains a Transfer-Encoding header field." + // + // but also: "If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. Such a message might indicate an attempt to perform + // request smuggling (Section 9.5) or response splitting (Section 9.4) and + // ought to be handled as an error. A sender MUST remove the received + // Content-Length field prior to forwarding such a message downstream." + // + // Reportedly, these appear in the wild. + delete(t.Header, "Content-Length") + + t.Chunked = true + return nil +} + +func (t *transferReader) protoAtLeast(m, n int) bool { + return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) +} + +// equalFold is strings.EqualFold, ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func equalFold(s, t string) bool { + if len(s) != len(t) { + return false + } + for i := 0; i < len(s); i++ { + if lower(s[i]) != lower(t[i]) { + return false + } + } + return true +} + +// Determine the expected body length, using RFC 7230 Section 3.3. This +// function is not a method, because ultimately it should be shared by +// ReadResponse and ReadRequest. +func fixLength(isResponse bool, status int, requestMethod string, header Header, chunked bool) (int64, error) { + isRequest := !isResponse + contentLens := header["Content-Length"] + + // Hardening against HTTP request smuggling + if len(contentLens) > 1 { + // Per RFC 7230 Section 3.3.2, prevent multiple + // Content-Length headers if they differ in value. + // If there are dups of the value, remove the dups. + // See Issue 16490. + first := textproto.TrimString(contentLens[0]) + for _, ct := range contentLens[1:] { + if first != textproto.TrimString(ct) { + return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens) + } + } + + // deduplicate Content-Length + header.Del("Content-Length") + header.Add("Content-Length", first) + + contentLens = header["Content-Length"] + } + + // Logic based on response type or status + if isResponse && noResponseBodyExpected(requestMethod) { + return 0, nil + } + if status/100 == 1 { + return 0, nil + } + switch status { + case 204, 304: + return 0, nil + } + + // Logic based on Transfer-Encoding + if chunked { + return -1, nil + } + + // Logic based on Content-Length + var cl string + if len(contentLens) == 1 { + cl = textproto.TrimString(contentLens[0]) + } + if cl != "" { + n, err := parseContentLength(cl) + if err != nil { + return -1, err + } + return n, nil + } + header.Del("Content-Length") + + if isRequest { + // RFC 7230 neither explicitly permits nor forbids an + // entity-body on a GET request so we permit one if + // declared, but we default to 0 here (not -1 below) + // if there's no mention of a body. + // Likewise, all other request methods are assumed to have + // no body if neither Transfer-Encoding chunked nor a + // Content-Length are set. + return 0, nil + } + + // Body-EOF logic based on other methods (like closing, or chunked coding) + return -1, nil +} + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +func parseContentLength(cl string) (int64, error) { + cl = textproto.TrimString(cl) + if cl == "" { + return -1, nil + } + n, err := strconv.ParseUint(cl, 10, 63) + if err != nil { + return 0, badStringError("bad Content-Length", cl) + } + return int64(n), nil + +} + +// Parse the trailer header. +func fixTrailer(header Header, chunked bool) (Header, error) { + vv, ok := header["Trailer"] + if !ok { + return nil, nil + } + if !chunked { + // Trailer and no chunking: + // this is an invalid use case for trailer header. + // Nevertheless, no error will be returned and we + // let users decide if this is a valid HTTP message. + // The Trailer header will be kept in Response.Header + // but not populate Response.Trailer. + // See issue #27197. + return nil, nil + } + header.Del("Trailer") + + trailer := make(Header) + var err error + for _, v := range vv { + foreachHeaderElement(v, func(key string) { + key = CanonicalHeaderKey(key) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + if err == nil { + err = badStringError("bad trailer key", key) + return + } + } + trailer[key] = nil + }) + } + if err != nil { + return nil, err + } + if len(trailer) == 0 { + return nil, nil + } + return trailer, nil +} + +// splitTwoDigitNumber splits a two-digit number into two digits. func splitTwoDigitNumber(num int) (int, int) { tens := num / 10 ones := num % 10 return tens, ones } -// -//// lower returns the ASCII lowercase version of b. -//func lower(b byte) byte { -// if 'A' <= b && b <= 'Z' { -// return b + ('a' - 'A') -// } -// return b -//} -// -//// foreachHeaderElement splits v according to the "#rule" construction -//// in RFC 7230 section 7 and calls fn for each non-empty element. -//func foreachHeaderElement(v string, fn func(string)) { -// v = textproto.TrimString(v) -// if v == "" { -// return -// } -// if !strings.Contains(v, ",") { -// fn(v) -// return -// } -// for _, f := range strings.Split(v, ",") { -// if f = textproto.TrimString(f); f != "" { -// fn(f) -// } -// } -//} -// -//func noResponseBodyExpected(requestMethod string) bool { -// return requestMethod == "HEAD" -//} -// -//func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } -// -//// bodyAllowedForStatus reports whether a given response status code -//// permits a body. See RFC 7230, section 3.3. -//func bodyAllowedForStatus(status int) bool { -// switch { -// case status >= 100 && status <= 199: -// return false -// case status == 204: -// return false -// case status == 304: -// return false -// } -// return true -//} + +// lower returns the ASCII lowercase version of b. +func lower(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// foreachHeaderElement splits v according to the "#rule" construction +// in RFC 7230 section 7 and calls fn for each non-empty element. +func foreachHeaderElement(v string, fn func(string)) { + v = textproto.TrimString(v) + if v == "" { + return + } + if !strings.Contains(v, ",") { + fn(v) + return + } + for _, f := range strings.Split(v, ",") { + if f = textproto.TrimString(f); f != "" { + fn(f) + } + } +} + +func noResponseBodyExpected(requestMethod string) bool { + return requestMethod == "HEAD" +} + +func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } + +// bodyAllowedForStatus reports whether a given response status code +// permits a body. See RFC 7230, section 3.3. +func bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + +// Determine whether to hang up after sending a request and body, or +// receiving a response and body +// 'header' is the request headers. +func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { + if major < 1 { + return true + } + + conv := header["Connection"] + hasClose := HeaderValuesContainsToken(conv, "close") + if major == 1 && minor == 0 { + return hasClose || !HeaderValuesContainsToken(conv, "keep-alive") + } + + if hasClose && removeCloseHeader { + header.Del("Connection") + } + + return hasClose +} + +// HeaderValuesContainsToken reports whether any string in values +// contains the provided token, ASCII case-insensitively. +func HeaderValuesContainsToken(values []string, token string) bool { + for _, v := range values { + if headerValueContainsToken(v, token) { + return true + } + } + return false +} + +// headerValueContainsToken reports whether v (assumed to be a +// 0#element, in the ABNF extension described in RFC 7230 section 7) +// contains token amongst its comma-separated tokens, ASCII +// case-insensitively. +func headerValueContainsToken(v string, token string) bool { + for comma := strings.IndexByte(v, ','); comma != -1; comma = strings.IndexByte(v, ',') { + if tokenEqual(trimOWS(v[:comma]), token) { + return true + } + v = v[comma+1:] + } + return tokenEqual(trimOWS(v), token) +} + +// tokenEqual reports whether t1 and t2 are equal, ASCII case-insensitively. +func tokenEqual(t1, t2 string) bool { + if len(t1) != len(t2) { + return false + } + for i, b := range t1 { + if b >= utf8.RuneSelf { + // No UTF-8 or non-ASCII allowed in tokens. + return false + } + if lowerASCII(byte(b)) != lowerASCII(t2[i]) { + return false + } + } + return true +} + +// trimOWS returns x with all optional whitespace removes from the +// beginning and end. +func trimOWS(x string) string { + // TODO: consider using strings.Trim(x, " \t") instead, + // if and when it's fast enough. See issue 10292. + // But this ASCII-only code will probably always beat UTF-8 + // aware code. + for len(x) > 0 && isOWS(x[0]) { + x = x[1:] + } + for len(x) > 0 && isOWS(x[len(x)-1]) { + x = x[:len(x)-1] + } + return x +} + +// lowerASCII returns the ASCII lowercase version of b. +func lowerASCII(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// isOWS reports whether b is an optional whitespace byte, as defined +// by RFC 7230 section 3.2.3. +func isOWS(b byte) bool { return b == ' ' || b == '\t' } diff --git a/x/http/transport.go b/x/http/transport.go index b9f845c..2dad490 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -228,11 +228,6 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Poll all ready tasks and act on them... rc := <-pc.reqch // blocking alive := true - resp := &Response{ - Request: rc.req, - Header: make(Header), - Trailer: make(Header), - } var bodyWriter *io.PipeWriter var respBody *hyper.Body = nil for alive { @@ -263,8 +258,17 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { client := (*hyper.ClientConn)(task.Value()) task.Free() + // Prepare the hyper.Request + hyperReq, err := newHyperRequest(rc.req) + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + // Send it! - sendTask := client.Send(rc.req.Req) + sendTask := client.Send(hyperReq) SetTaskId(sendTask, ReceiveResp) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { @@ -289,14 +293,13 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { hyperResp := (*hyper.Response)(task.Value()) task.Free() - readResponseLineAndHeader(resp, hyperResp) - //err = readTransfer(resp, hyperResp) - //if err != nil { - // rc.ch <- responseAndError{err: err} - // // Free the resources - // FreeResources(task, respBody, bodyWriter, exec, pc, rc) - // return - //} + resp, err := ReadResponse(hyperResp, rc.req) + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } respBody = hyperResp.Body() resp.Body, bodyWriter = io.Pipe() @@ -395,6 +398,8 @@ func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { conn := (*ConnData)(handle.GetData()) if conn.ReadBuf.Base == nil { conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) + //base := make([]byte, suggestedSize) + //conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Pointer(&base[0])), c.Uint(suggestedSize)) conn.ReadBufFilled = 0 } *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+conn.ReadBufFilled)), c.Uint(suggestedSize-conn.ReadBufFilled)) From ba2a9d098f61d4f2a5fe85c712497cf57fbeb8ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Fri, 16 Aug 2024 18:09:37 +0800 Subject: [PATCH 10/21] WIP(x/http/client): Implement http.Post() and redirection logic --- go.mod | 7 +- go.sum | 8 +- x/http/_demo/post/post.go | 4 +- x/http/_demo/redirect/redirect.go | 26 ++ x/http/_demo/server/redirectServer.go | 26 ++ x/http/_demo/timeout/timeout.go | 4 +- x/http/client.go | 586 ++++++++++++++++++++++++-- x/http/clone.go | 11 + x/http/cookie.go | 232 ++++++++++ x/http/header.go | 71 +++- x/http/http.go | 27 ++ x/http/jar.go | 27 ++ x/http/request.go | 286 ++++++++++--- x/http/response.go | 19 +- x/http/transfer.go | 35 +- x/http/transport.go | 180 +++++--- x/http/util.go | 146 +++++++ 17 files changed, 1482 insertions(+), 213 deletions(-) create mode 100644 x/http/_demo/redirect/redirect.go create mode 100644 x/http/_demo/server/redirectServer.go create mode 100644 x/http/clone.go create mode 100644 x/http/cookie.go create mode 100644 x/http/http.go create mode 100644 x/http/jar.go create mode 100644 x/http/util.go diff --git a/go.mod b/go.mod index 4082df2..f961f75 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,9 @@ module github.com/goplus/llgoexamples go 1.20 -require github.com/goplus/llgo v0.9.7-0.20240812013847-321766fd4641 +require ( + github.com/goplus/llgo v0.9.7-0.20240816085229-53d2d080f4c4 + golang.org/x/net v0.28.0 +) + +require golang.org/x/text v0.17.0 // indirect diff --git a/go.sum b/go.sum index e3abd53..4c64063 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,6 @@ -github.com/goplus/llgo v0.9.7-0.20240812013847-321766fd4641 h1:VIJ38bCFRIIr62YXyRKkxy6GXYVA6R3xqAb0HkcoUgw= -github.com/goplus/llgo v0.9.7-0.20240812013847-321766fd4641/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= +github.com/goplus/llgo v0.9.7-0.20240816085229-53d2d080f4c4 h1:fqqbWhWaoseSplLJF8OTkNGl4Kruqm1wQWT/Yooq6E4= +github.com/goplus/llgo v0.9.7-0.20240816085229-53d2d080f4c4/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= diff --git a/x/http/_demo/post/post.go b/x/http/_demo/post/post.go index a86e805..4958a8e 100644 --- a/x/http/_demo/post/post.go +++ b/x/http/_demo/post/post.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "fmt" "io" @@ -8,7 +9,8 @@ import ( ) func main() { - resp, err := http.Post("https://jsonplaceholder.typicode.com/posts", "application/json; charset=UTF-8", nil) + data := []byte(`{"id":1,"title":"foo","body":"bar","userId":"1"}`) + resp, err := http.Post("https://jsonplaceholder.typicode.com/posts", "application/json; charset=UTF-8", bytes.NewBuffer(data)) if err != nil { fmt.Println(err) return diff --git a/x/http/_demo/redirect/redirect.go b/x/http/_demo/redirect/redirect.go new file mode 100644 index 0000000..48465b7 --- /dev/null +++ b/x/http/_demo/redirect/redirect.go @@ -0,0 +1,26 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgoexamples/x/http" +) + +func main() { + resp, err := http.Get("http://localhost:8080") // Start "../server/redirectServer.go" before running + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) + fmt.Println(resp.Proto) + resp.PrintHeaders() + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + defer resp.Body.Close() +} diff --git a/x/http/_demo/server/redirectServer.go b/x/http/_demo/server/redirectServer.go new file mode 100644 index 0000000..a6830af --- /dev/null +++ b/x/http/_demo/server/redirectServer.go @@ -0,0 +1,26 @@ +package main + +import ( + "fmt" + "log" + "net/http" +) + +func main() { + http.HandleFunc("/", handleInitialRequest) + http.HandleFunc("/redirect", handleRedirectRequest) + + fmt.Println("Server is running on http://localhost:8080") + log.Fatal(http.ListenAndServe(":8080", nil)) +} + +func handleInitialRequest(w http.ResponseWriter, r *http.Request) { + log.Println("Received initial request, redirecting...") + http.Redirect(w, r, "/redirect", http.StatusSeeOther) +} + +func handleRedirectRequest(w http.ResponseWriter, r *http.Request) { + log.Println("Received redirect request, sending response...") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "Hello redirect") +} diff --git a/x/http/_demo/timeout/timeout.go b/x/http/_demo/timeout/timeout.go index 42f8bf8..6eece04 100644 --- a/x/http/_demo/timeout/timeout.go +++ b/x/http/_demo/timeout/timeout.go @@ -10,8 +10,8 @@ import ( func main() { client := &http.Client{ - Timeout: time.Millisecond, // Set a small timeout to ensure it will time out - //Timeout: time.Second * 5, + //Timeout: time.Millisecond, // Set a small timeout to ensure it will time out + Timeout: time.Second * 5, } req, err := http.NewRequest("GET", "https://www.baidu.com", nil) if err != nil { diff --git a/x/http/client.go b/x/http/client.go index 72e9688..31362a9 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -1,16 +1,25 @@ package http import ( + "context" + "encoding/base64" "errors" + "fmt" "io" + "log" "net/url" + "sort" "strings" + "sync" + "sync/atomic" "time" ) type Client struct { - Transport RoundTripper - Timeout time.Duration + Transport RoundTripper + CheckRedirect func(req *Request, via []*Request) error + Jar CookieJar + Timeout time.Duration } var DefaultClient = &Client{} @@ -38,6 +47,14 @@ func (c *Client) Get(url string) (*Response, error) { return c.Do(req) } +func alwaysFalse() bool { return false } + +// ErrUseLastResponse can be returned by Client.CheckRedirect hooks to +// control how redirects are processed. If returned, the next request +// is not sent and the most recent response is returned with its body +// unclosed. +var ErrUseLastResponse = errors.New("net/http: use last response") + func Post(url, contentType string, body io.Reader) (resp *Response, err error) { return DefaultClient.Post(url, contentType, body) } @@ -70,15 +87,15 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { } } var ( - //deadline = c.deadline() - reqs []*Request - resp *Response - //copyHeaders = c.makeHeadersCopier(req) + deadline = c.deadline() + reqs []*Request + resp *Response + copyHeaders = c.makeHeadersCopier(req) reqBodyClosed = false // have we closed the current req.Body? // Redirect behavior: - //redirectMethod string - //includeBody bool + redirectMethod string + includeBody bool ) uerr := func(err error) error { // the body may have been closed already by c.send() @@ -98,42 +115,236 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { } } - // For all but the first request, create the next - // request hop and replace req. for { + // For all but the first request, create the next + // request hop and replace req. if len(reqs) > 0 { + loc := resp.Header.Get("Location") + if loc == "" { + // While most 3xx responses include a Location, it is not + // required and 3xx responses without a Location have been + // observed in the wild. See issues #17773 and #49281. + return resp, nil + } + u, err := req.URL.Parse(loc) + if err != nil { + resp.closeBody() + return nil, uerr(fmt.Errorf("failed to parse Location header %q: %v", loc, err)) + } + // TODO(spongehah) redirect: Why use host := "" + //host := "" + host := u.Host + + if req.Host != "" && req.Host != req.URL.Host { + // If the caller specified a custom Host header and the + // redirect location is relative, preserve the Host header + // through the redirect. See issue #22233. + if u, _ := url.Parse(loc); u != nil && !u.IsAbs() { + host = req.Host + } + } + ireq := reqs[0] + req = &Request{ + Method: redirectMethod, + Response: resp, + URL: u, + Header: make(Header), + Host: host, + Cancel: ireq.Cancel, + ctx: ireq.ctx, + } + if includeBody && ireq.GetBody != nil { + req.Body, err = ireq.GetBody() + if err != nil { + resp.closeBody() + return nil, uerr(err) + } + req.ContentLength = ireq.ContentLength + } + + // Copy original headers before setting the Referer, + // in case the user set Referer on their first request. + // If they really want to override, they can do it in + // their CheckRedirect func. + copyHeaders(req) + // Add the Referer header from the most recent + // request URL to the new one, if it's not https->http: + if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL, req.Header.Get("Referer")); ref != "" { + req.Header.Set("Referer", ref) + } + err = c.checkRedirect(req, reqs) + + // Sentinel error to let users select the + // previous response, without closing its + // body. See Issue 10069. + if err == ErrUseLastResponse { + return resp, nil + } + + // Close the previous response's body. But + // read at least some of the body so if it's + // small the underlying TCP connection will be + // re-used. No need to check for errors: if it + // fails, the Transport won't reuse it anyway. + const maxBodySlurpSize = 2 << 10 + if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize { + io.CopyN(io.Discard, resp.Body, maxBodySlurpSize) + } + resp.Body.Close() + + if err != nil { + // Special case for Go 1 compatibility: return both the response + // and an error if the CheckRedirect function failed. + // See https://golang.org/issue/3795 + // The resp.Body has already been closed. + ue := uerr(err) + ue.(*url.Error).URL = loc + return resp, ue + } } reqs = append(reqs, req) var err error - if resp, err = c.send(req, c.Timeout); err != nil { + var didTimeout func() bool + if resp, didTimeout, err = c.send(req, deadline); err != nil { // c.send() always closes req.Body reqBodyClosed = true + if !deadline.IsZero() && didTimeout() { + err = &httpError{ + err: err.Error() + " (Client.Timeout exceeded while awaiting headers)", + timeout: true, + } + } return nil, uerr(err) } var shouldRedirect bool - //redirectMethod, shouldRedirect, includeBody = redirectBehavior(req.Method, resp, reqs[0]) - _, shouldRedirect, _ = redirectBehavior(req.Method, resp, reqs[0]) + redirectMethod, shouldRedirect, includeBody = redirectBehavior(req.Method, resp, reqs[0]) if !shouldRedirect { return resp, nil - } else { - // TODO(spongehah) - return nil, errors.New("TODO: redirect not implemented") } req.closeBody() } } -func (c *Client) send(req *Request, timeout time.Duration) (*Response, error) { - return send(req, c.transport(), timeout) +// didTimeout is non-nil only if err != nil. +func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { + // TODO(spongehah) cookie + if c.Jar != nil { + for _, cookie := range c.Jar.Cookies(req.URL) { + req.AddCookie(cookie) + } + } + resp, didTimeout, err = send(req, c.transport(), deadline) + if err != nil { + return nil, didTimeout, err + } + if c.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + c.Jar.SetCookies(req.URL, rc) + } + } + return resp, nil, nil } -func send(req *Request, rt RoundTripper, timeout time.Duration) (resp *Response, err error) { - req.timeout = timeout - return rt.RoundTrip(req) +// send issues an HTTP request. +// Caller should close resp.Body when done reading from it. +func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { + req := ireq // req is either the original request, or a modified fork + + if rt == nil { + req.closeBody() + return nil, alwaysFalse, errors.New("http: no Client.Transport or DefaultTransport") + } + + if req.URL == nil { + req.closeBody() + return nil, alwaysFalse, errors.New("http: nil Request.URL") + } + + if req.RequestURI != "" { + req.closeBody() + return nil, alwaysFalse, errors.New("http: Request.RequestURI can't be set in client requests") + } + + // forkReq forks req into a shallow clone of ireq the first + // time it's called. + forkReq := func() { + if ireq == req { + req = new(Request) + *req = *ireq // shallow clone + } + } + + // Most the callers of send (Get, Post, et al) don't need + // Headers, leaving it uninitialized. We guarantee to the + // Transport that this has been initialized, though. + if req.Header == nil { + forkReq() + req.Header = make(Header) + } + + if u := req.URL.User; u != nil && req.Header.Get("Authorization") == "" { + username := u.Username() + password, _ := u.Password() + forkReq() + req.Header = cloneOrMakeHeader(ireq.Header) + req.Header.Set("Authorization", "Basic "+basicAuth(username, password)) + } + + if !deadline.IsZero() { + forkReq() + } + + // TODO(spongehah) timeout + //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) + sub := deadline.Sub(time.Now()) + req.timeout = sub + resp, err = rt.RoundTrip(req) + if err != nil { + //stopTimer() + if resp != nil { + log.Printf("RoundTripper returned a response & error; ignoring response") + } + //if tlsErr, ok := err.(tls.RecordHeaderError); ok { + // // If we get a bad TLS record header, check to see if the + // // response looks like HTTP and give a more helpful error. + // // See golang.org/issue/11111. + // if string(tlsErr.RecordHeader[:]) == "HTTP/" { + // err = ErrSchemeMismatch + // } + //} + return nil, didTimeout, err + } + if resp == nil { + return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a nil *Response with a nil error", rt) + } + if resp.Body == nil { + // The documentation on the Body field says “The http Client and Transport + // guarantee that Body is always non-nil, even on responses without a body + // or responses with a zero-length body.” Unfortunately, we didn't document + // that same constraint for arbitrary RoundTripper implementations, and + // RoundTripper implementations in the wild (mostly in tests) assume that + // they can use a nil Body to mean an empty one (similar to Request.Body). + // (See https://golang.org/issue/38095.) + // + // If the ContentLength allows the Body to be empty, fill in an empty one + // here to ensure that it is non-nil. + if resp.ContentLength > 0 && req.Method != "HEAD" { + return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a *Response with content length %d but a nil Body", rt, resp.ContentLength) + } + resp.Body = io.NopCloser(strings.NewReader("")) + } + //if !deadline.IsZero() { + // resp.Body = &cancelTimerBody{ + // stop: stopTimer, + // rc: resp.Body, + // reqDidTimeout: didTimeout, + // } + //} + return resp, nil, nil } // redirectBehavior describes what should happen when the @@ -192,29 +403,330 @@ func urlErrorOp(method string) string { return method } -// ToLower returns the lowercase version of s if s is ASCII and printable. -func ToLower(s string) (lower string, ok bool) { - if !IsPrint(s) { - return "", false +func stripPassword(u *url.URL) string { + _, passSet := u.User.Password() + if passSet { + return strings.Replace(u.String(), u.User.String()+"@", u.User.Username()+":***@", 1) } - return strings.ToLower(s), true + return u.String() } -// IsPrint returns whether s is ASCII and printable according to -// https://tools.ietf.org/html/rfc20#section-4.2. -func IsPrint(s string) bool { - for i := 0; i < len(s); i++ { - if s[i] < ' ' || s[i] > '~' { - return false +// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt +// "To receive authorization, the client sends the userid and password, +// separated by a single colon (":") character, within a base64 +// encoded string in the credentials." +// It is not meant to be urlencoded. +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +func (c *Client) deadline() time.Time { + if c.Timeout > 0 { + return time.Now().Add(c.Timeout) + } + return time.Time{} +} + +// cancelTimerBody is an io.ReadCloser that wraps rc with two features: +// 1. On Read error or close, the stop func is called. +// 2. On Read failure, if reqDidTimeout is true, the error is wrapped and +// marked as net.Error that hit its timeout. +type cancelTimerBody struct { + stop func() // stops the time.Timer waiting to cancel the request + rc io.ReadCloser + reqDidTimeout func() bool +} + +func (b *cancelTimerBody) Read(p []byte) (n int, err error) { + n, err = b.rc.Read(p) + if err == nil { + return n, nil + } + if err == io.EOF { + return n, err + } + if b.reqDidTimeout() { + err = &httpError{ + err: err.Error() + " (Client.Timeout or context cancellation while reading body)", + timeout: true, } } + return n, err +} + +func (b *cancelTimerBody) Close() error { + err := b.rc.Close() + b.stop() + return err +} + +// setRequestCancel sets req.Cancel and adds a deadline context to req +// if deadline is non-zero. The RoundTripper's type is used to +// determine whether the legacy CancelRequest behavior should be used. +// +// As background, there are three ways to cancel a request: +// First was Transport.CancelRequest. (deprecated) +// Second was Request.Cancel. +// Third was Request.Context. +// This function populates the second and third, and uses the first if it really needs to. +func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) { + if deadline.IsZero() { + return nop, alwaysFalse + } + // TODO(spongehah) todo: map[string]github.com/goplus/llgo/x/http.RoundTripper + //knownTransport := knownRoundTripperImpl(rt, req) + oldCtx := req.Context() + + //if req.Cancel == nil && knownTransport { + if req.Cancel == nil { + // If they already had a Request.Context that's + // expiring sooner, do nothing: + if !timeBeforeContextDeadline(deadline, oldCtx) { + return nop, alwaysFalse + } + + var cancelCtx func() + req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) + return cancelCtx, func() bool { return time.Now().After(deadline) } + } + initialReqCancel := req.Cancel // the user's original Request.Cancel, if any + + var cancelCtx func() + if timeBeforeContextDeadline(deadline, oldCtx) { + req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) + } + + cancel := make(chan struct{}) + req.Cancel = cancel + + doCancel := func() { + // The second way in the func comment above: + close(cancel) + // The first way, used only for RoundTripper + // implementations written before Go 1.5 or Go 1.6. + type canceler interface{ CancelRequest(*Request) } + if v, ok := rt.(canceler); ok { + v.CancelRequest(req) + } + } + + stopTimerCh := make(chan struct{}) + var once sync.Once + stopTimer = func() { + once.Do(func() { + close(stopTimerCh) + if cancelCtx != nil { + cancelCtx() + } + }) + } + + timer := time.NewTimer(time.Until(deadline)) + var timedOut atomic.Bool + + go func() { + select { + case <-initialReqCancel: + doCancel() + timer.Stop() + case <-timer.C: + timedOut.Store(true) + doCancel() + case <-stopTimerCh: + timer.Stop() + } + }() + + return stopTimer, timedOut.Load +} + +// timeBeforeContextDeadline reports whether the non-zero Time t is +// before ctx's deadline, if any. If ctx does not have a deadline, it +// always reports true (the deadline is considered infinite). +func timeBeforeContextDeadline(t time.Time, ctx context.Context) bool { + d, ok := ctx.Deadline() + if !ok { + return true + } + return t.Before(d) +} + +/* +// knownRoundTripperImpl reports whether rt is a RoundTripper that's +// maintained by the Go team and known to implement the latest +// optional semantics (notably contexts). The Request is used +// to check whether this particular request is using an alternate protocol, +// in which case we need to check the RoundTripper for that protocol. +func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { + switch t := rt.(type) { + case *Transport: + if altRT := t.alternateRoundTripper(req); altRT != nil { + return knownRoundTripperImpl(altRT, req) + } + return true + //case *http2Transport, http2noDialH2RoundTripper: + // return true + } + // There's a very minor chance of a false positive with this. + // Instead of detecting our golang.org/x/net/http2.Transport, + // it might detect a Transport type in a different http2 + // package. But I know of none, and the only problem would be + // some temporarily leaked goroutines if the transport didn't + // support contexts. So this is a good enough heuristic: + if reflect.TypeOf(rt).String() == "*http2.Transport" { + return true + } + return false +}*/ + +// makeHeadersCopier makes a function that copies headers from the +// initial Request, ireq. For every redirect, this function must be called +// so that it can copy headers into the upcoming Request. +func (c *Client) makeHeadersCopier(ireq *Request) func(*Request) { + // The headers to copy are from the very initial request. + // We use a closured callback to keep a reference to these original headers. + var ( + ireqhdr = cloneOrMakeHeader(ireq.Header) + icookies map[string][]*Cookie + ) + if c.Jar != nil && ireq.Header.Get("Cookie") != "" { + icookies = make(map[string][]*Cookie) + for _, c := range ireq.Cookies() { + icookies[c.Name] = append(icookies[c.Name], c) + } + } + + preq := ireq // The previous request + return func(req *Request) { + // If Jar is present and there was some initial cookies provided + // via the request header, then we may need to alter the initial + // cookies as we follow redirects since each redirect may end up + // modifying a pre-existing cookie. + // + // Since cookies already set in the request header do not contain + // information about the original domain and path, the logic below + // assumes any new set cookies override the original cookie + // regardless of domain or path. + // + // See https://golang.org/issue/17494 + if c.Jar != nil && icookies != nil { + var changed bool + resp := req.Response // The response that caused the upcoming redirect + for _, c := range resp.Cookies() { + if _, ok := icookies[c.Name]; ok { + delete(icookies, c.Name) + changed = true + } + } + if changed { + ireqhdr.Del("Cookie") + var ss []string + for _, cs := range icookies { + for _, c := range cs { + ss = append(ss, c.Name+"="+c.Value) + } + } + sort.Strings(ss) // Ensure deterministic headers + ireqhdr.Set("Cookie", strings.Join(ss, "; ")) + } + } + + // Copy the initial request's Header values + // (at least the safe ones). + for k, vv := range ireqhdr { + if shouldCopyHeaderOnRedirect(k, preq.URL, req.URL) { + req.Header[k] = vv + } + } + + preq = req // Update previous Request with the current request + } +} + +func shouldCopyHeaderOnRedirect(headerKey string, initial, dest *url.URL) bool { + switch CanonicalHeaderKey(headerKey) { + case "Authorization", "Www-Authenticate", "Cookie", "Cookie2": + // Permit sending auth/cookie headers from "foo.com" + // to "sub.foo.com". + + // Note that we don't send all cookies to subdomains + // automatically. This function is only used for + // Cookies set explicitly on the initial outgoing + // client request. Cookies automatically added via the + // CookieJar mechanism continue to follow each + // cookie's scope as set by Set-Cookie. But for + // outgoing requests with the Cookie header set + // directly, we don't know their scope, so we assume + // it's for *.domain.com. + + ihost := idnaASCIIFromURL(initial) + dhost := idnaASCIIFromURL(dest) + return isDomainOrSubdomain(dhost, ihost) + } + // All other headers are copied: return true } -func stripPassword(u *url.URL) string { - _, passSet := u.User.Password() - if passSet { - return strings.Replace(u.String(), u.User.String()+"@", u.User.Username()+":***@", 1) +// isDomainOrSubdomain reports whether sub is a subdomain (or exact +// match) of the parent domain. +// +// Both domains must already be in canonical form. +func isDomainOrSubdomain(sub, parent string) bool { + if sub == parent { + return true } - return u.String() + // If sub is "foo.example.com" and parent is "example.com", + // that means sub must end in "."+parent. + // Do it without allocating. + if !strings.HasSuffix(sub, parent) { + return false + } + return sub[len(sub)-len(parent)-1] == '.' +} + +// refererForURL returns a referer without any authentication info or +// an empty string if lastReq scheme is https and newReq scheme is http. +// If the referer was explicitly set, then it will continue to be used. +func refererForURL(lastReq, newReq *url.URL, explicitRef string) string { + // https://tools.ietf.org/html/rfc7231#section-5.5.2 + // "Clients SHOULD NOT include a Referer header field in a + // (non-secure) HTTP request if the referring page was + // transferred with a secure protocol." + if lastReq.Scheme == "https" && newReq.Scheme == "http" { + return "" + } + if explicitRef != "" { + return explicitRef + } + + referer := lastReq.String() + if lastReq.User != nil { + // This is not very efficient, but is the best we can + // do without: + // - introducing a new method on URL + // - creating a race condition + // - copying the URL struct manually, which would cause + // maintenance problems down the line + auth := lastReq.User.String() + "@" + referer = strings.Replace(referer, auth, "", 1) + } + return referer +} + +// checkRedirect calls either the user's configured CheckRedirect +// function, or the default. +func (c *Client) checkRedirect(req *Request, via []*Request) error { + fn := c.CheckRedirect + if fn == nil { + fn = defaultCheckRedirect + } + return fn(req, via) +} + +func defaultCheckRedirect(req *Request, via []*Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + return nil } diff --git a/x/http/clone.go b/x/http/clone.go new file mode 100644 index 0000000..ff67949 --- /dev/null +++ b/x/http/clone.go @@ -0,0 +1,11 @@ +package http + +// cloneOrMakeHeader invokes Header.Clone but if the +// result is nil, it'll instead make and return a non-nil Header. +func cloneOrMakeHeader(hdr Header) Header { + clone := hdr.Clone() + if clone == nil { + clone = make(Header) + } + return clone +} diff --git a/x/http/cookie.go b/x/http/cookie.go new file mode 100644 index 0000000..4b7175c --- /dev/null +++ b/x/http/cookie.go @@ -0,0 +1,232 @@ +package http + +import ( + "log" + "net/textproto" + "strconv" + "strings" + "time" +) + +// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an +// HTTP response or the Cookie header of an HTTP request. +// +// See https://tools.ietf.org/html/rfc6265 for details. +type Cookie struct { + Name string + Value string + + Path string // optional + Domain string // optional + Expires time.Time // optional + RawExpires string // for reading cookies only + + // MaxAge=0 means no 'Max-Age' attribute specified. + // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0' + // MaxAge>0 means Max-Age attribute present and given in seconds + MaxAge int + Secure bool + HttpOnly bool + SameSite SameSite + Raw string + Unparsed []string // Raw text of unparsed attribute-value pairs +} + +// SameSite allows a server to define a cookie attribute making it impossible for +// the browser to send this cookie along with cross-site requests. The main +// goal is to mitigate the risk of cross-origin information leakage, and provide +// some protection against cross-site request forgery attacks. +// +// See https://tools.ietf.org/html/draft-ietf-httpbis-cookie-same-site-00 for details. +type SameSite int + +const ( + SameSiteDefaultMode SameSite = iota + 1 + SameSiteLaxMode + SameSiteStrictMode + SameSiteNoneMode +) + +var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") + +func sanitizeCookieName(n string) string { + return cookieNameSanitizer.Replace(n) +} + +// sanitizeCookieValue produces a suitable cookie-value from v. +// https://tools.ietf.org/html/rfc6265#section-4.1.1 +// +// cookie-value = *cookie-octet / ( DQUOTE *cookie-octet DQUOTE ) +// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E +// ; US-ASCII characters excluding CTLs, +// ; whitespace DQUOTE, comma, semicolon, +// ; and backslash +// +// We loosen this as spaces and commas are common in cookie values +// but we produce a quoted cookie-value if and only if v contains +// commas or spaces. +// See https://golang.org/issue/7243 for the discussion. +func sanitizeCookieValue(v string) string { + v = sanitizeOrWarn("Cookie.Value", validCookieValueByte, v) + if len(v) == 0 { + return v + } + if strings.ContainsAny(v, " ,") { + return `"` + v + `"` + } + return v +} + +func validCookieValueByte(b byte) bool { + return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\' +} + +func sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string { + ok := true + for i := 0; i < len(v); i++ { + if valid(v[i]) { + continue + } + log.Printf("net/http: invalid byte %q in %s; dropping invalid bytes", v[i], fieldName) + ok = false + break + } + if ok { + return v + } + buf := make([]byte, 0, len(v)) + for i := 0; i < len(v); i++ { + if b := v[i]; valid(b) { + buf = append(buf, b) + } + } + return string(buf) +} + +// readSetCookies parses all "Set-Cookie" values from +// the header h and returns the successfully parsed Cookies. +func readSetCookies(h Header) []*Cookie { + cookieCount := len(h["Set-Cookie"]) + if cookieCount == 0 { + return []*Cookie{} + } + cookies := make([]*Cookie, 0, cookieCount) + for _, line := range h["Set-Cookie"] { + parts := strings.Split(textproto.TrimString(line), ";") + if len(parts) == 1 && parts[0] == "" { + continue + } + parts[0] = textproto.TrimString(parts[0]) + name, value, ok := strings.Cut(parts[0], "=") + if !ok { + continue + } + name = textproto.TrimString(name) + if !isCookieNameValid(name) { + continue + } + value, ok = parseCookieValue(value, true) + if !ok { + continue + } + c := &Cookie{ + Name: name, + Value: value, + Raw: line, + } + for i := 1; i < len(parts); i++ { + parts[i] = textproto.TrimString(parts[i]) + if len(parts[i]) == 0 { + continue + } + + attr, val, _ := strings.Cut(parts[i], "=") + lowerAttr, isASCII := ToLower(attr) + if !isASCII { + continue + } + val, ok = parseCookieValue(val, false) + if !ok { + c.Unparsed = append(c.Unparsed, parts[i]) + continue + } + + switch lowerAttr { + case "samesite": + lowerVal, ascii := ToLower(val) + if !ascii { + c.SameSite = SameSiteDefaultMode + continue + } + switch lowerVal { + case "lax": + c.SameSite = SameSiteLaxMode + case "strict": + c.SameSite = SameSiteStrictMode + case "none": + c.SameSite = SameSiteNoneMode + default: + c.SameSite = SameSiteDefaultMode + } + continue + case "secure": + c.Secure = true + continue + case "httponly": + c.HttpOnly = true + continue + case "domain": + c.Domain = val + continue + case "max-age": + secs, err := strconv.Atoi(val) + if err != nil || secs != 0 && val[0] == '0' { + break + } + if secs <= 0 { + secs = -1 + } + c.MaxAge = secs + continue + case "expires": + c.RawExpires = val + exptime, err := time.Parse(time.RFC1123, val) + if err != nil { + exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", val) + if err != nil { + c.Expires = time.Time{} + break + } + } + c.Expires = exptime.UTC() + continue + case "path": + c.Path = val + continue + } + c.Unparsed = append(c.Unparsed, parts[i]) + } + cookies = append(cookies, c) + } + return cookies +} + +func isCookieNameValid(raw string) bool { + if raw == "" { + return false + } + return strings.IndexFunc(raw, isNotToken) < 0 +} + +func parseCookieValue(raw string, allowDoubleQuote bool) (string, bool) { + // Strip the quotes, if present. + if allowDoubleQuote && len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' { + raw = raw[1 : len(raw)-1] + } + for i := 0; i < len(raw); i++ { + if !validCookieValueByte(raw[i]) { + return "", false + } + } + return raw, true +} diff --git a/x/http/header.go b/x/http/header.go index 076db0f..6515c48 100644 --- a/x/http/header.go +++ b/x/http/header.go @@ -81,8 +81,75 @@ func (h Header) Del(key string) { // returned without modifications. func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } -// AppendToResponseHeader (HeadersForEachCallback) prints each header to the console -func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { +// Clone returns a copy of h or nil if h is nil. +func (h Header) Clone() Header { + if h == nil { + return nil + } + + // Find total number of values. + nv := 0 + for _, vv := range h { + nv += len(vv) + } + sv := make([]string, nv) // shared backing array for headers' values + h2 := make(Header, len(h)) + for k, vv := range h { + if vv == nil { + // Preserve nil values. ReverseProxy distinguishes + // between nil and zero-length header values. + h2[k] = nil + continue + } + n := copy(sv, vv) + h2[k] = sv[:n:n] + sv = sv[n:] + } + return h2 +} + +// hasToken reports whether token appears with v, ASCII +// case-insensitive, with space or comma boundaries. +// token must be all lowercase. +// v may contain mixed cased. +func hasToken(v, token string) bool { + if len(token) > len(v) || token == "" { + return false + } + if v == token { + return true + } + for sp := 0; sp <= len(v)-len(token); sp++ { + // Check that first character is good. + // The token is ASCII, so checking only a single byte + // is sufficient. We skip this potential starting + // position if both the first byte and its potential + // ASCII uppercase equivalent (b|0x20) don't match. + // False positives ('^' => '~') are caught by EqualFold. + if b := v[sp]; b != token[0] && b|0x20 != token[0] { + continue + } + // Check that start pos is on a valid token boundary. + if sp > 0 && !isTokenBoundary(v[sp-1]) { + continue + } + // Check that end pos is on a valid token boundary. + if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) { + continue + } + if EqualFold(v[sp:sp+len(token)], token) { + return true + } + } + return false +} + +func isTokenBoundary(b byte) bool { + return b == ' ' || b == ',' || b == '\t' +} + +// appendToResponseHeader (HeadersForEachCallback) prints each header to the console +func appendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { resp := (*Response)(userdata) nameStr := c.GoString((*int8)(c.Pointer(name)), nameLen) valueStr := c.GoString((*int8)(c.Pointer(value)), valueLen) diff --git a/x/http/http.go b/x/http/http.go new file mode 100644 index 0000000..f668906 --- /dev/null +++ b/x/http/http.go @@ -0,0 +1,27 @@ +package http + +import "strings" + +// splitTwoDigitNumber splits a two-digit number into two digits. +func splitTwoDigitNumber(num int) (int, int) { + tens := num / 10 + ones := num % 10 + return tens, ones +} + +func isNotToken(r rune) bool { + return !IsTokenRune(r) +} + +// removeEmptyPort strips the empty port in ":port" to "" +// as mandated by RFC 3986 Section 6.2.3. +func removeEmptyPort(host string) string { + if hasPort(host) { + return strings.TrimSuffix(host, ":") + } + return host +} + +// Given a string of the form "host", "host:port", or "[ipv6::address]:port", +// return true if the string includes a port. +func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") } diff --git a/x/http/jar.go b/x/http/jar.go new file mode 100644 index 0000000..5c3de0d --- /dev/null +++ b/x/http/jar.go @@ -0,0 +1,27 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "net/url" +) + +// A CookieJar manages storage and use of cookies in HTTP requests. +// +// Implementations of CookieJar must be safe for concurrent use by multiple +// goroutines. +// +// The net/http/cookiejar package provides a CookieJar implementation. +type CookieJar interface { + // SetCookies handles the receipt of the cookies in a reply for the + // given URL. It may or may not choose to save the cookies, depending + // on the jar's policy and implementation. + SetCookies(u *url.URL, cookies []*Cookie) + + // Cookies returns the cookies to send in a request for the given URL. + // It is up to the implementation to honor the standard cookie use + // restrictions such as in RFC 6265. + Cookies(u *url.URL) []*Cookie +} diff --git a/x/http/request.go b/x/http/request.go index c84cd9b..f6e6f16 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -1,11 +1,17 @@ package http import ( + "bytes" + "context" "fmt" "io" + "net/textproto" "net/url" + "strings" "time" + "golang.org/x/net/idna" + "github.com/goplus/llgo/c" "github.com/goplus/llgo/c/os" "github.com/goplus/llgoexamples/rust/hyper" @@ -24,79 +30,124 @@ type Request struct { TransferEncoding []string Close bool Host string - timeout time.Duration -} - -type postBody struct { - data []byte - len uintptr - readLen uintptr + //Form url.Values + //PostForm url.Values + //MultipartForm *multipart.Form + Trailer Header + RemoteAddr string + RequestURI string + //TLS *tls.ConnectionState + Cancel <-chan struct{} + Response *Response + timeout time.Duration + ctx context.Context } -type uploadBody struct { - fd c.Int - buf []byte - len uintptr -} - -var DefaultChunkSize uintptr = 8192 +var defaultChunkSize uintptr = 8192 func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { + if method == "" { + // We document that "" means "GET" for Request.Method, and people have + // relied on that from NewRequest, so keep that working. + // We still enforce validMethod for non-empty methods. + method = "GET" + } + if !validMethod(method) { + return nil, fmt.Errorf("net/http: invalid method %q", method) + } + //if ctx == nil { + // return nil, errors.New("net/http: nil Context") + //} u, err := url.Parse(urlStr) if err != nil { return nil, err } - //rc, ok := body.(io.ReadCloser) - //if !ok && body != nil { - // rc = io.NopCloser(body) - //} - request := &Request{ + rc, ok := body.(io.ReadCloser) + if !ok && body != nil { + rc = io.NopCloser(body) + } + // The host's colon:port should be normalized. See Issue 14836. + u.Host = removeEmptyPort(u.Host) + req := &Request{ + //ctx: ctx, Method: method, URL: u, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, Header: make(Header), + Body: rc, Host: u.Host, - //Body: rc, - timeout: 0, } - request.Header.Set("Host", request.Host) + if body != nil { + switch v := body.(type) { + case *bytes.Buffer: + req.ContentLength = int64(v.Len()) + buf := v.Bytes() + req.GetBody = func() (io.ReadCloser, error) { + r := bytes.NewReader(buf) + return io.NopCloser(r), nil + } + case *bytes.Reader: + req.ContentLength = int64(v.Len()) + snapshot := *v + req.GetBody = func() (io.ReadCloser, error) { + r := snapshot + return io.NopCloser(&r), nil + } + case *strings.Reader: + req.ContentLength = int64(v.Len()) + snapshot := *v + req.GetBody = func() (io.ReadCloser, error) { + r := snapshot + return io.NopCloser(&r), nil + } + default: + // This is where we'd set it to -1 (at least + // if body != NoBody) to mean unknown, but + // that broke people during the Go 1.8 testing + // period. People depend on it being 0 I + // guess. Maybe retry later. See Issue 18117. + } + // For client requests, Request.ContentLength of 0 + // means either actually 0, or unknown. The only way + // to explicitly say that the ContentLength is zero is + // to set the Body to nil. But turns out too much code + // depends on NewRequest returning a non-nil Body, + // so we use a well-known ReadCloser variable instead + // and have the http package also treat that sentinel + // variable to mean explicitly zero. + if req.GetBody != nil && req.ContentLength == 0 { + req.Body = NoBody + req.GetBody = func() (io.ReadCloser, error) { return NoBody, nil } + } + } - return request, nil + return req, nil } -func PrintInformational(userdata c.Pointer, resp *hyper.Response) { +func printInformational(userdata c.Pointer, resp *hyper.Response) { status := resp.Status() fmt.Println("Informational (1xx): ", status) } -func SetPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { - //upload := (*uploadBody)(userdata) - //res := os.Read(upload.fd, c.Pointer(&upload.buf[0]), upload.len) - //if res > 0 { - // *chunk = hyper.CopyBuf(&upload.buf[0], uintptr(res)) - // return hyper.PollReady - //} - //if res == 0 { - // *chunk = nil - // os.Close(upload.fd) - // return hyper.PollReady - //} - body := (*postBody)(userdata) - if body.len > 0 { - if body.len > DefaultChunkSize { - *chunk = hyper.CopyBuf(&body.data[body.readLen], DefaultChunkSize) - body.readLen += DefaultChunkSize - body.len -= DefaultChunkSize - } else { - *chunk = hyper.CopyBuf(&body.data[body.readLen], body.len) - body.readLen += body.len - body.len = 0 +func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { + req := (*Request)(userdata) + buffer := make([]byte, defaultChunkSize) + n, err := req.Body.Read(buffer) + if err != nil { + if err == io.EOF { + *chunk = nil + return hyper.PollReady } + fmt.Println("error reading upload file: ", err) + return hyper.PollError + } + if n > 0 { + *chunk = hyper.CopyBuf(&buffer[0], uintptr(n)) return hyper.PollReady } - if body.len == 0 { + if n == 0 { *chunk = nil return hyper.PollReady } @@ -107,7 +158,7 @@ func SetPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In func newHyperRequest(req *Request) (*hyper.Request, error) { host := req.Host - uri := req.URL.Path + uri := req.URL.RequestURI() method := req.Method // Prepare the request hyperReq := hyper.NewRequest() @@ -124,27 +175,13 @@ func newHyperRequest(req *Request) (*hyper.Request, error) { return nil, fmt.Errorf("error setting header: Host: %s\n", host) } - if method == "POST" { - //var upload uploadBody - //upload.fd = os.Open(c.Str("/Users/spongehah/go/src/llgo/x/http/_demo/post/example.txt"), os.O_RDONLY) - //if upload.fd < 0 { - // return nil, fmt.Errorf("error opening file to upload: %s\n", c.GoString(c.Strerror(os.Errno))) - //} - //upload.len = 8192 - //upload.buf = make([]byte, upload.len) + if method == "POST" && req.Body != nil { req.Header.Set("expect", "100-continue") - hyperReq.OnInformational(PrintInformational, nil) - postData := []byte(`{"id":1,"title":"foo","body":"bar","userId":"1"}`) - - reqBody := &postBody{ - data: postData, - len: uintptr(len(postData)), - } + hyperReq.OnInformational(printInformational, nil) hyperReqBody := hyper.NewBody() - hyperReqBody.SetUserdata(c.Pointer(reqBody)) - //hyperReqBody.SetUserdata(c.Pointer(&upload)) - hyperReqBody.SetDataFunc(SetPostData) + hyperReqBody.SetUserdata(c.Pointer(req)) + hyperReqBody.SetDataFunc(setPostData) hyperReq.SetBody(hyperReqBody) } @@ -185,3 +222,120 @@ func (r *Request) closeBody() error { } return r.Body.Close() } + +func validMethod(method string) bool { + /* + Method = "OPTIONS" ; Section 9.2 + | "GET" ; Section 9.3 + | "HEAD" ; Section 9.4 + | "POST" ; Section 9.5 + | "PUT" ; Section 9.6 + | "DELETE" ; Section 9.7 + | "TRACE" ; Section 9.8 + | "CONNECT" ; Section 9.9 + | extension-method + extension-method = token + token = 1* + */ + return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 +} + +// Context returns the request's context. To change the context, use +// Clone or WithContext. +// +// The returned context is always non-nil; it defaults to the +// background context. +// +// For outgoing client requests, the context controls cancellation. +// +// For incoming server requests, the context is canceled when the +// client's connection closes, the request is canceled (with HTTP/2), +// or when the ServeHTTP method returns. +func (r *Request) Context() context.Context { + if r.ctx != nil { + return r.ctx + } + return context.Background() +} + +// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4, +// AddCookie does not attach more than one Cookie header field. That +// means all cookies, if any, are written into the same line, +// separated by semicolon. +// AddCookie only sanitizes c's name and value, and does not sanitize +// a Cookie header already present in the request. +func (r *Request) AddCookie(c *Cookie) { + s := fmt.Sprintf("%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value)) + if c := r.Header.Get("Cookie"); c != "" { + r.Header.Set("Cookie", c+"; "+s) + } else { + r.Header.Set("Cookie", s) + } +} + +// requiresHTTP1 reports whether this request requires being sent on +// an HTTP/1 connection. +func (r *Request) requiresHTTP1() bool { + return hasToken(r.Header.Get("Connection"), "upgrade") && + EqualFold(r.Header.Get("Upgrade"), "websocket") +} + +// Cookies parses and returns the HTTP cookies sent with the request. +func (r *Request) Cookies() []*Cookie { + return readCookies(r.Header, "") +} + +// readCookies parses all "Cookie" values from the header h and +// returns the successfully parsed Cookies. +// +// if filter isn't empty, only cookies of that name are returned. +func readCookies(h Header, filter string) []*Cookie { + lines := h["Cookie"] + if len(lines) == 0 { + return []*Cookie{} + } + + cookies := make([]*Cookie, 0, len(lines)+strings.Count(lines[0], ";")) + for _, line := range lines { + line = textproto.TrimString(line) + + var part string + for len(line) > 0 { // continue since we have rest + part, line, _ = strings.Cut(line, ";") + part = textproto.TrimString(part) + if part == "" { + continue + } + name, val, _ := strings.Cut(part, "=") + name = textproto.TrimString(name) + if !isCookieNameValid(name) { + continue + } + if filter != "" && filter != name { + continue + } + val, ok := parseCookieValue(val, true) + if !ok { + continue + } + cookies = append(cookies, &Cookie{Name: name, Value: val}) + } + } + return cookies +} + +func idnaASCII(v string) (string, error) { + // TODO: Consider removing this check after verifying performance is okay. + // Right now punycode verification, length checks, context checks, and the + // permissible character tests are all omitted. It also prevents the ToASCII + // call from salvaging an invalid IDN, when possible. As a result it may be + // possible to have two IDNs that appear identical to the user where the + // ASCII-only version causes an error downstream whereas the non-ASCII + // version does not. + // Note that for correct ASCII IDNs ToASCII will only do considerably more + // work, but it will not cause an allocation. + if Is(v) { + return v, nil + } + return idna.Lookup.ToASCII(v) +} diff --git a/x/http/response.go b/x/http/response.go index c99bade..174d2fc 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -20,15 +20,21 @@ type Response struct { ContentLength int64 TransferEncoding []string Close bool - Trailer Header - Request *Request + //Trailer Header + Request *Request +} + +func (r *Response) closeBody() { + if r.Body != nil { + r.Body.Close() + } } func ReadResponse(hyperResp *hyper.Response, req *Request) (*Response, error) { resp := &Response{ Request: req, Header: make(Header), - Trailer: make(Header), + //Trailer: make(Header), } readResponseLineAndHeader(resp, hyperResp) @@ -54,7 +60,7 @@ func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { resp.Proto = fmt.Sprintf("HTTP/%d.%d", resp.ProtoMajor, resp.ProtoMinor) headers := hyperResp.Headers() - headers.Foreach(AppendToResponseHeader, c.Pointer(resp)) + headers.Foreach(appendToResponseHeader, c.Pointer(resp)) } // RFC 7234, section 5.4: Should treat @@ -71,3 +77,8 @@ func fixPragmaCacheControl(header Header) { } } } + +// Cookies parses and returns the cookies set in the Set-Cookie headers. +func (r *Response) Cookies() []*Cookie { + return readSetCookies(r.Header) +} diff --git a/x/http/transfer.go b/x/http/transfer.go index 70f082e..ac50296 100644 --- a/x/http/transfer.go +++ b/x/http/transfer.go @@ -94,7 +94,7 @@ func readTransfer(msg any) (err error) { } // Trailer - t.Trailer, err = fixTrailer(t.Header, t.Chunked) + //t.Trailer, err = fixTrailer(t.Header, t.Chunked) // If there is no Content-Length or chunked Transfer-Encoding on a *Response // and the status is not 1xx, 204 or 304, then the body is unbounded. @@ -148,7 +148,7 @@ func readTransfer(msg any) (err error) { rr.TransferEncoding = []string{"chunked"} } rr.Close = t.Close - rr.Trailer = t.Trailer + //rr.Trailer = t.Trailer } return nil @@ -174,7 +174,7 @@ func (t *transferReader) parseTransferEncoding() error { if len(raw) != 1 { return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} } - if !equalFold(raw[0], "chunked") { + if !EqualFold(raw[0], "chunked") { return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} } @@ -199,20 +199,6 @@ func (t *transferReader) protoAtLeast(m, n int) bool { return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) } -// equalFold is strings.EqualFold, ASCII only. It reports whether s and t -// are equal, ASCII-case-insensitively. -func equalFold(s, t string) bool { - if len(s) != len(t) { - return false - } - for i := 0; i < len(s); i++ { - if lower(s[i]) != lower(t[i]) { - return false - } - } - return true -} - // Determine the expected body length, using RFC 7230 Section 3.3. This // function is not a method, because ultimately it should be shared by // ReadResponse and ReadRequest. @@ -343,21 +329,6 @@ func fixTrailer(header Header, chunked bool) (Header, error) { return trailer, nil } -// splitTwoDigitNumber splits a two-digit number into two digits. -func splitTwoDigitNumber(num int) (int, int) { - tens := num / 10 - ones := num % 10 - return tens, ones -} - -// lower returns the ASCII lowercase version of b. -func lower(b byte) byte { - if 'A' <= b && b <= 'Z' { - return b + ('a' - 'A') - } - return b -} - // foreachHeaderElement splits v according to the "#rule" construction // in RFC 7230 section 7 and calls fn for each non-empty element. func foreachHeaderElement(v string, fn func(string)) { diff --git a/x/http/transport.go b/x/http/transport.go index 2dad490..f9bfa46 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -3,6 +3,8 @@ package http import ( "fmt" "io" + "net/url" + "sync/atomic" "unsafe" "github.com/goplus/llgo/c" @@ -12,7 +14,7 @@ import ( "github.com/goplus/llgoexamples/rust/hyper" ) -type ConnData struct { +type connData struct { TcpHandle libuv.Tcp ConnectReq libuv.Connect ReadBuf libuv.Buf @@ -24,20 +26,21 @@ type ConnData struct { } type Transport struct { + altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme } -// TaskId The unique identifier of the next task polled from the executor -type TaskId c.Int +// taskId The unique identifier of the next task polled from the executor +type taskId c.Int const ( - NotSet TaskId = iota - Send - ReceiveResp - ReceiveRespBody + notSet taskId = iota + sending + receiveResp + receiveRespBody ) const ( - DefaultHTTPPort = "80" + defaultHTTPPort = "80" ) var DefaultTransport RoundTripper = &Transport{} @@ -54,7 +57,7 @@ type persistConn struct { //nwrite int64 // bytes written //writech chan writeRequest // written by roundTrip; read by writeLoop //closech chan struct{} // closed when conn closed - conn *ConnData + conn *connData t *Transport reqch chan requestAndChan // written by roundTrip; read by readLoop cancelch chan freeChan @@ -82,7 +85,7 @@ type responseAndError struct { type connAndTimeoutChan struct { _ incomparable - conn *ConnData + conn *connData timeoutch chan struct{} } @@ -105,29 +108,29 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { - host := req.Host + host := req.URL.Hostname() port := req.URL.Port() if port == "" { // Hyper only supports http - port = DefaultHTTPPort + port = defaultHTTPPort } loop := libuv.DefaultLoop() //conn := (*ConnData)(c.Calloc(1, unsafe.Sizeof(ConnData{}))) - conn := new(ConnData) + conn := new(connData) if conn == nil { return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") } // If timeout is set, start the timer timeoutch := make(chan struct{}, 1) - if req.timeout != 0 { + if req.timeout > 0 { libuv.InitTimer(loop, &conn.TimeoutTimer) ct := &connAndTimeoutChan{ conn: conn, timeoutch: timeoutch, } (*libuv.Handle)(c.Pointer(&conn.TimeoutTimer)).SetData(c.Pointer(ct)) - conn.TimeoutTimer.Start(OnTimeout, uint64(req.timeout.Milliseconds()), 0) + conn.TimeoutTimer.Start(onTimeout, uint64(req.timeout.Milliseconds()), 0) } libuv.InitTcp(loop, &conn.TcpHandle) @@ -148,7 +151,7 @@ func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { //conn.ConnectReq.Data = c.Pointer(conn) (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) - status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, OnConnect) + status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, onConnect) if status != 0 { close(timeoutch) return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) @@ -209,7 +212,7 @@ func (pc *persistConn) roundTrip(req *Request) (*Response, error) { // It processes incoming requests, sends them to the server, and handles responses. func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Hookup the IO - hyperIo := NewIoWithConnReadWrite(pc.conn) + hyperIo := newIoWithConnReadWrite(pc.conn) // We need an executor generally to poll futures exec := hyper.NewExecutor() @@ -218,7 +221,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { opts.Exec(exec) handshakeTask := hyper.Handshake(hyperIo, opts) - SetTaskId(handshakeTask, Send) + setTaskId(handshakeTask, sending) // Let's wait for the handshake to finish... exec.Push(handshakeTask) @@ -241,13 +244,12 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { default: task := exec.Poll() if task == nil { - //break loop.Run(libuv.RUN_ONCE) continue } - switch (TaskId)(uintptr(task.Userdata())) { - case Send: - err := CheckTaskType(task, Send) + switch (taskId)(uintptr(task.Userdata())) { + case sending: + err := checkTaskType(task, sending) if err != nil { rc.ch <- responseAndError{err: err} // Free the resources @@ -269,7 +271,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Send it! sendTask := client.Send(hyperReq) - SetTaskId(sendTask, ReceiveResp) + setTaskId(sendTask, receiveResp) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { rc.ch <- responseAndError{err: fmt.Errorf("failed to send request")} @@ -280,8 +282,8 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // For this example, no longer need the client client.Free() - case ReceiveResp: - err := CheckTaskType(task, ReceiveResp) + case receiveResp: + err := checkTaskType(task, receiveResp) if err != nil { rc.ch <- responseAndError{err: err} // Free the resources @@ -309,19 +311,19 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Response has been returned, stop the timer pc.conn.IsCompleted = 1 // Stop the timer - if rc.req.timeout != 0 { + if rc.req.timeout > 0 { pc.conn.TimeoutTimer.Stop() (*libuv.Handle)(c.Pointer(&pc.conn.TimeoutTimer)).Close(nil) } dataTask := respBody.Data() - SetTaskId(dataTask, ReceiveRespBody) + setTaskId(dataTask, receiveRespBody) exec.Push(dataTask) // No longer need the response hyperResp.Free() - case ReceiveRespBody: - err := CheckTaskType(task, ReceiveRespBody) + case receiveRespBody: + err := checkTaskType(task, receiveRespBody) if err != nil { rc.ch <- responseAndError{err: err} // Free the resources @@ -350,7 +352,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { task.Free() dataTask := respBody.Data() - SetTaskId(dataTask, ReceiveRespBody) + setTaskId(dataTask, receiveRespBody) exec.Push(dataTask) break @@ -369,7 +371,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { FreeResources(task, respBody, bodyWriter, exec, pc, rc) alive = false - case NotSet: + case notSet: // A background task for hyper_client completed... task.Free() } @@ -378,24 +380,24 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { //} } -// OnConnect is the libuv callback for a successful connection -func OnConnect(req *libuv.Connect, status c.Int) { +// onConnect is the libuv callback for a successful connection +func onConnect(req *libuv.Connect, status c.Int) { //conn := (*ConnData)(req.Data) //conn := (*struct{ data *ConnData })(c.Pointer(req)).data - conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) if status < 0 { c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) return } - (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(AllocBuffer, OnRead) + (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(allocBuffer, onRead) } -// AllocBuffer allocates a buffer for reading from a socket -func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { +// allocBuffer allocates a buffer for reading from a socket +func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { //conn := (*ConnData)(handle.Data) //conn := (*struct{ data *ConnData })(c.Pointer(handle)).data - conn := (*ConnData)(handle.GetData()) + conn := (*connData)(handle.GetData()) if conn.ReadBuf.Base == nil { conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) //base := make([]byte, suggestedSize) @@ -405,11 +407,11 @@ func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+conn.ReadBufFilled)), c.Uint(suggestedSize-conn.ReadBufFilled)) } -// OnRead is the libuv callback for reading from a socket +// onRead is the libuv callback for reading from a socket // This callback function is called when data is available to be read -func OnRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { +func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { // Get the connection data associated with the stream - conn := (*ConnData)((*libuv.Handle)(c.Pointer(stream)).GetData()) + conn := (*connData)((*libuv.Handle)(c.Pointer(stream)).GetData()) //conn := (*ConnData)(stream.Data) //conn := (*struct{ data *ConnData })(c.Pointer(stream)).data @@ -427,10 +429,10 @@ func OnRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { } } -// ReadCallBack read callback function for Hyper library -func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { +// readCallBack read callback function for Hyper library +func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { // Get the user data (connection data) - conn := (*ConnData)(userdata) + conn := (*connData)(userdata) // If there's data in the buffer if conn.ReadBufFilled > 0 { @@ -462,11 +464,11 @@ func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uin return hyper.IoPending } -// OnWrite is the libuv callback for writing to a socket +// onWrite is the libuv callback for writing to a socket // Callback function called after a write operation completes -func OnWrite(req *libuv.Write, status c.Int) { +func onWrite(req *libuv.Write, status c.Int) { // Get the connection data associated with the write request - conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) //conn := (*ConnData)(req.Data) //conn := (*struct{ data *ConnData })(c.Pointer(req)).data @@ -479,10 +481,10 @@ func OnWrite(req *libuv.Write, status c.Int) { } } -// WriteCallBack write callback function for Hyper library -func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { +// writeCallBack write callback function for Hyper library +func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { // Get the user data (connection data) - conn := (*ConnData)(userdata) + conn := (*connData)(userdata) // Create a libuv buffer initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) //req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) @@ -492,7 +494,7 @@ func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui //req.Data = c.Pointer(conn) // Perform the asynchronous write operation - ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, OnWrite) + ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, onWrite) // If the write operation was successfully initiated if ret >= 0 { // Return the number of bytes to be written @@ -510,8 +512,8 @@ func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui return hyper.IoPending } -// OnTimeout is the libuv callback for a timeout -func OnTimeout(handle *libuv.Timer) { +// onTimeout is the libuv callback for a timeout +func onTimeout(handle *libuv.Timer) { ct := (*connAndTimeoutChan)((*libuv.Handle)(c.Pointer(handle)).GetData()) if ct.conn.IsCompleted != 1 { ct.conn.IsCompleted = 1 @@ -521,25 +523,25 @@ func OnTimeout(handle *libuv.Timer) { (*libuv.Handle)(c.Pointer(&ct.conn.TimeoutTimer)).Close(nil) } -// NewIoWithConnReadWrite creates a new IO with read and write callbacks -func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { +// newIoWithConnReadWrite creates a new IO with read and write callbacks +func newIoWithConnReadWrite(connData *connData) *hyper.Io { hyperIo := hyper.NewIo() hyperIo.SetUserdata(c.Pointer(connData)) - hyperIo.SetRead(ReadCallBack) - hyperIo.SetWrite(WriteCallBack) + hyperIo.SetRead(readCallBack) + hyperIo.SetWrite(writeCallBack) return hyperIo } -// SetTaskId Set TaskId to the task's userdata as a unique identifier -func SetTaskId(task *hyper.Task, userData TaskId) { +// setTaskId Set taskId to the task's userdata as a unique identifier +func setTaskId(task *hyper.Task, userData taskId) { var data = userData task.SetUserdata(unsafe.Pointer(uintptr(data))) } -// CheckTaskType checks the task type -func CheckTaskType(task *hyper.Task, curTaskId TaskId) error { +// checkTaskType checks the task type +func checkTaskType(task *hyper.Task, curTaskId taskId) error { switch curTaskId { - case Send: + case sending: if task.Type() == hyper.TaskError { c.Printf(c.Str("handshake task error!\n")) return Fail((*hyper.Error)(task.Value())) @@ -548,7 +550,7 @@ func CheckTaskType(task *hyper.Task, curTaskId TaskId) error { return fmt.Errorf("unexpected task type\n") } return nil - case ReceiveResp: + case receiveResp: if task.Type() == hyper.TaskError { c.Printf(c.Str("send task error!\n")) return Fail((*hyper.Error)(task.Value())) @@ -558,13 +560,13 @@ func CheckTaskType(task *hyper.Task, curTaskId TaskId) error { return fmt.Errorf("unexpected task type\n") } return nil - case ReceiveRespBody: + case receiveRespBody: if task.Type() == hyper.TaskError { c.Printf(c.Str("body error!\n")) return Fail((*hyper.Error)(task.Value())) } return nil - case NotSet: + case notSet: } return fmt.Errorf("unexpected TaskId\n") } @@ -617,7 +619,7 @@ func CloseChannels(rc requestAndChan, pc *persistConn) { } // FreeConnData frees the connection data -func FreeConnData(conn *ConnData) { +func FreeConnData(conn *connData) { if conn.ReadWaker != nil { conn.ReadWaker.Free() conn.ReadWaker = nil @@ -631,3 +633,49 @@ func FreeConnData(conn *ConnData) { conn.ReadBuf.Base = nil } } + +type httpError struct { + err string + timeout bool +} + +func (e *httpError) Error() string { return e.err } +func (e *httpError) Timeout() bool { return e.timeout } +func (e *httpError) Temporary() bool { return true } + +var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} + +func nop() {} + +/*// alternateRoundTripper returns the alternate RoundTripper to use +// for this request if the Request's URL scheme requires one, +// or nil for the normal case of using the Transport. +func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { + if !t.useRegisteredProtocol(req) { + return nil + } + altProto, _ := t.altProto.Load().(map[string]RoundTripper) + return altProto[req.URL.Scheme] +} + +// useRegisteredProtocol reports whether an alternate protocol (as registered +// with Transport.RegisterProtocol) should be respected for this request. +func (t *Transport) useRegisteredProtocol(req *Request) bool { + if req.URL.Scheme == "https" && req.requiresHTTP1() { + // If this request requires HTTP/1, don't use the + // "https" alternate protocol, which is used by the + // HTTP/2 code to take over requests if there's an + // existing cached HTTP/2 connection. + return false + } + return true +} +*/ + +func idnaASCIIFromURL(url *url.URL) string { + addr := url.Hostname() + if v, err := idnaASCII(addr); err == nil { + addr = v + } + return addr +} \ No newline at end of file diff --git a/x/http/util.go b/x/http/util.go new file mode 100644 index 0000000..674f481 --- /dev/null +++ b/x/http/util.go @@ -0,0 +1,146 @@ +package http + +import ( + "strings" + "unicode" +) + +/** + * Copied from the libraries that llgo cannot be used + */ + +var isTokenTable = [127]bool{ // httpguts.isTokenTable + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, +} + +func IsTokenRune(r rune) bool { // httpguts.IsTokenRune + i := int(r) + return i < len(isTokenTable) && isTokenTable[i] +} + +// IsPrint returns whether s is ASCII and printable according to +// https://tools.ietf.org/html/rfc20#section-4.2. +func IsPrint(s string) bool { // ascii.IsPrint + for i := 0; i < len(s); i++ { + if s[i] < ' ' || s[i] > '~' { + return false + } + } + return true +} + +// ToLower returns the lowercase version of s if s is ASCII and printable. +func ToLower(s string) (lower string, ok bool) { // ascii.ToLower + if !IsPrint(s) { + return "", false + } + return strings.ToLower(s), true +} + +// EqualFold is strings.EqualFold, ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func EqualFold(s, t string) bool { // ascii.EqualFold + if len(s) != len(t) { + return false + } + for i := 0; i < len(s); i++ { + if lower(s[i]) != lower(t[i]) { + return false + } + } + return true +} + +// lower returns the ASCII lowercase version of b. +func lower(b byte) byte { // ascii.lower + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// Is returns whether s is ASCII. +func Is(s string) bool { // ascii.Is + for i := 0; i < len(s); i++ { + if s[i] > unicode.MaxASCII { + return false + } + } + return true +} From 1bad20a6ef7df2776b12daefb9bbe2e32b94ce2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Mon, 19 Aug 2024 18:32:58 +0800 Subject: [PATCH 11/21] WIP(x/http/client): http.PostForm() & some function improvements --- x/http/_demo/postform/postform.go | 31 ++ x/http/_demo/timeout/timeout.go | 4 +- x/http/client.go | 12 +- x/http/request.go | 85 +++- x/http/transport.go | 644 +++++++++++++++++++++++++++--- x/http/util.go | 87 ++++ 6 files changed, 788 insertions(+), 75 deletions(-) create mode 100644 x/http/_demo/postform/postform.go diff --git a/x/http/_demo/postform/postform.go b/x/http/_demo/postform/postform.go new file mode 100644 index 0000000..5315ca9 --- /dev/null +++ b/x/http/_demo/postform/postform.go @@ -0,0 +1,31 @@ +package main + +import ( + "fmt" + "io" + "net/url" + + "github.com/goplus/llgo/x/http" +) + +func main() { + formData := url.Values{ + "name": {"John Doe"}, + "email": {"johndoe@example.com"}, + } + + resp, err := http.PostForm("http://httpbin.org/post", formData) + if err != nil { + fmt.Println(err) + return + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + defer resp.Body.Close() +} diff --git a/x/http/_demo/timeout/timeout.go b/x/http/_demo/timeout/timeout.go index 6eece04..42f8bf8 100644 --- a/x/http/_demo/timeout/timeout.go +++ b/x/http/_demo/timeout/timeout.go @@ -10,8 +10,8 @@ import ( func main() { client := &http.Client{ - //Timeout: time.Millisecond, // Set a small timeout to ensure it will time out - Timeout: time.Second * 5, + Timeout: time.Millisecond, // Set a small timeout to ensure it will time out + //Timeout: time.Second * 5, } req, err := http.NewRequest("GET", "https://www.baidu.com", nil) if err != nil { diff --git a/x/http/client.go b/x/http/client.go index 31362a9..4fc6e41 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -68,6 +68,14 @@ func (c *Client) Post(url, contentType string, body io.Reader) (resp *Response, return c.Do(req) } +func PostForm(url string, data url.Values) (resp *Response, err error) { + return DefaultClient.PostForm(url, data) +} + +func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) { + return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} + func (c *Client) Do(req *Request) (*Response, error) { return c.do(req) } @@ -474,7 +482,6 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi if deadline.IsZero() { return nop, alwaysFalse } - // TODO(spongehah) todo: map[string]github.com/goplus/llgo/x/http.RoundTripper //knownTransport := knownRoundTripperImpl(rt, req) oldCtx := req.Context() @@ -552,8 +559,7 @@ func timeBeforeContextDeadline(t time.Time, ctx context.Context) bool { return t.Before(d) } -/* -// knownRoundTripperImpl reports whether rt is a RoundTripper that's +/*// knownRoundTripperImpl reports whether rt is a RoundTripper that's // maintained by the Go team and known to implement the latest // optional semantics (notably contexts). The Request is used // to check whether this particular request is using an alternate protocol, diff --git a/x/http/request.go b/x/http/request.go index f6e6f16..38b74bb 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -45,7 +45,34 @@ type Request struct { var defaultChunkSize uintptr = 8192 -func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { +// NewRequest wraps NewRequestWithContext using context.Background. +func NewRequest(method, url string, body io.Reader) (*Request, error) { + return NewRequestWithContext(context.Background(), method, url, body) +} + +// NewRequestWithContext returns a new Request given a method, URL, and +// optional body. +// +// If the provided body is also an io.Closer, the returned +// Request.Body is set to body and will be closed by the Client +// methods Do, Post, and PostForm, and Transport.RoundTrip. +// +// NewRequestWithContext returns a Request suitable for use with +// Client.Do or Transport.RoundTrip. To create a request for use with +// testing a Server Handler, either use the NewRequest function in the +// net/http/httptest package, use ReadRequest, or manually update the +// Request fields. For an outgoing client request, the context +// controls the entire lifetime of a request and its response: +// obtaining a connection, sending the request, and reading the +// response headers and body. See the Request type's documentation for +// the difference between inbound and outbound request fields. +// +// If body is of type *bytes.Buffer, *bytes.Reader, or +// *strings.Reader, the returned request's ContentLength is set to its +// exact value (instead of -1), GetBody is populated (so 307 and 308 +// redirects can replay the body), and Body is set to NoBody if the +// ContentLength is 0. +func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.Reader) (*Request, error) { if method == "" { // We document that "" means "GET" for Request.Method, and people have // relied on that from NewRequest, so keep that working. @@ -69,7 +96,7 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { // The host's colon:port should be normalized. See Issue 14836. u.Host = removeEmptyPort(u.Host) req := &Request{ - //ctx: ctx, + ctx: ctx, Method: method, URL: u, Proto: "HTTP/1.1", @@ -131,10 +158,49 @@ func printInformational(userdata c.Pointer, resp *hyper.Response) { fmt.Println("Informational (1xx): ", status) } +type postReq struct { + req *Request + buf []byte +} + func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { - req := (*Request)(userdata) - buffer := make([]byte, defaultChunkSize) - n, err := req.Body.Read(buffer) + req := (*postReq)(userdata) + n, err := req.req.Body.Read(req.buf) + if err != nil { + if err == io.EOF { + *chunk = nil + return hyper.PollReady + } + fmt.Println("error reading upload file: ", err) + return hyper.PollError + } + if n > 0 { + *chunk = hyper.CopyBuf(&req.buf[0], uintptr(n)) + return hyper.PollReady + } + if n == 0 { + *chunk = nil + return hyper.PollReady + } + + fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + return hyper.PollError +} + +func setPostDataNoCopy(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { + type buf struct { + data *uint8 + len uintptr + Unused [16]byte + } + req := (*postReq)(userdata) + buffer := &buf{ + data: &req.buf[0], + len: uintptr(len(req.buf)), + } + + *chunk = (*hyper.Buf)(c.Pointer(buffer)) + n, err := req.req.Body.Read(req.buf) if err != nil { if err == io.EOF { *chunk = nil @@ -144,7 +210,6 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In return hyper.PollError } if n > 0 { - *chunk = hyper.CopyBuf(&buffer[0], uintptr(n)) return hyper.PollReady } if n == 0 { @@ -152,7 +217,7 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In return hyper.PollReady } - fmt.Printf("error reading upload file: %s\n", c.GoString(c.Strerror(os.Errno))) + fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) return hyper.PollError } @@ -180,7 +245,11 @@ func newHyperRequest(req *Request) (*hyper.Request, error) { hyperReq.OnInformational(printInformational, nil) hyperReqBody := hyper.NewBody() - hyperReqBody.SetUserdata(c.Pointer(req)) + reqData := &postReq{ + req: req, + buf: make([]byte, 3), + } + hyperReqBody.SetUserdata(c.Pointer(reqData)) hyperReqBody.SetDataFunc(setPostData) hyperReq.SetBody(hyperReqBody) } diff --git a/x/http/transport.go b/x/http/transport.go index f9bfa46..e54b357 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -1,9 +1,12 @@ package http import ( + "context" + "errors" "fmt" "io" "net/url" + "sync" "sync/atomic" "unsafe" @@ -26,9 +29,21 @@ type connData struct { } type Transport struct { - altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme + altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme + reqMu sync.Mutex + reqCanceler map[cancelKey]func(error) + //Proxy func(*Request) (*url.URL, error) + + // MaxConnsPerHost optionally limits the total number of + // connections per host, including connections in the dialing, + // active, and idle states. On limit violation, dials will block. + // + // Zero means no limit. + MaxConnsPerHost int } +var DefaultTransport RoundTripper = &Transport{} + // taskId The unique identifier of the next task polled from the executor type taskId c.Int @@ -43,15 +58,13 @@ const ( defaultHTTPPort = "80" ) -var DefaultTransport RoundTripper = &Transport{} - // persistConn wraps a connection, usually a persistent one // (but may be used for non-keep-alive requests as well) type persistConn struct { // alt optionally specifies the TLS NextProto RoundTripper. // This is used for HTTP/2 today and future protocols later. // If it's non-nil, the rest of the fields are unused. - //alt RoundTripper + alt RoundTripper //br *bufio.Reader // from conn //bw *bufio.Writer // to conn //nwrite int64 // bytes written @@ -94,47 +107,331 @@ type freeChan struct { freech chan struct{} } +// A cancelKey is the key of the reqCanceler map. +// We wrap the *Request in this type since we want to use the original request, +// not any transient one created by roundTrip. +type cancelKey struct { + req *Request +} + +// transportRequest is a wrapper around a *Request that adds +// optional extra headers to write and stores any error to return +// from roundTrip. +type transportRequest struct { + *Request // original request, not to be mutated + extra Header // extra headers to write, or nil + //trace *httptrace.ClientTrace // optional + cancelKey cancelKey + + mu sync.Mutex // guards err + err error // first setError value for mapRoundTripError to consider +} + +// useRegisteredProtocol reports whether an alternate protocol (as registered +// with Transport.RegisterProtocol) should be respected for this request. +func (t *Transport) useRegisteredProtocol(req *Request) bool { + if req.URL.Scheme == "https" && req.requiresHTTP1() { + // If this request requires HTTP/1, don't use the + // "https" alternate protocol, which is used by the + // HTTP/2 code to take over requests if there's an + // existing cached HTTP/2 connection. + return false + } + return true +} + +// alternateRoundTripper returns the alternate RoundTripper to use +// for this request if the Request's URL scheme requires one, +// or nil for the normal case of using the Transport. +func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { + if !t.useRegisteredProtocol(req) { + return nil + } + altProto, _ := t.altProto.Load().(map[string]RoundTripper) + return altProto[req.URL.Scheme] +} + func (t *Transport) RoundTrip(req *Request) (*Response, error) { - pconn, err := t.getConn(req) - if err != nil { + //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) + //ctx := req.Context() + //trace := httptrace.ContextClientTrace(ctx) + + if req.URL == nil { + req.closeBody() + return nil, errors.New("http: nil Request.URL") + } + if req.Header == nil { + req.closeBody() + return nil, errors.New("http: nil Request.Header") + } + scheme := req.URL.Scheme + isHTTP := scheme == "http" || scheme == "https" + if isHTTP { + for k, vv := range req.Header { + if !ValidHeaderFieldName(k) { + req.closeBody() + return nil, fmt.Errorf("net/http: invalid header field name %q", k) + } + for _, v := range vv { + if !ValidHeaderFieldValue(v) { + req.closeBody() + // Don't include the value in the error, because it may be sensitive. + return nil, fmt.Errorf("net/http: invalid header field value for %q", k) + } + } + } + } + + origReq := req + cancelKey := cancelKey{origReq} + req = setupRewindBody(req) + + if altRT := t.alternateRoundTripper(req); altRT != nil { + if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol { + return resp, err + } + var err error + req, err = rewindBody(req) + if err != nil { + return nil, err + } + } + if !isHTTP { + req.closeBody() + return nil, badStringError("unsupported protocol scheme", scheme) + } + if req.Method != "" && !validMethod(req.Method) { + req.closeBody() + return nil, fmt.Errorf("net/http: invalid method %q", req.Method) + } + if req.URL.Host == "" { + req.closeBody() + return nil, errors.New("http: no Host in request URL") + } + + for { + // TODO(spongehah) timeout: because of that ctx not initialized ( initialized in setRequestCancel() ) + //select { + //case <-ctx.Done(): + // req.closeBody() + // return nil, ctx.Err() + //default: + //} + + // treq gets modified by roundTrip, so we need to recreate for each retry. + //treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey} + treq := &transportRequest{Request: req, cancelKey: cancelKey} + cm, err := t.connectMethodForRequest(treq) + if err != nil { + req.closeBody() + return nil, err + } + + // Get the cached or newly-created connection to either the + // host (for http or https), the http proxy, or the http proxy + // pre-CONNECTed to https server. In any case, we'll be ready + // to send it requests. + pconn, err := t.getConn(treq, cm) + if err != nil { + t.setReqCanceler(cancelKey, nil) + req.closeBody() + return nil, err + } + + var resp *Response + if pconn.alt != nil { + // HTTP/2 path. + t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest + resp, err = pconn.alt.RoundTrip(req) + } else { + resp, err = pconn.roundTrip(treq) + } + if err == nil { + resp.Request = origReq + return resp, nil + } + + // Failed. Clean up and determine whether to retry. + // TODO(spongehah) Retry & ConnPool return nil, err } - var resp *Response - resp, err = pconn.roundTrip(req) - if err != nil { +} + +func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error) { + //req := treq.Request + //trace := treq.trace + //ctx := req.Context() + //if trace != nil && trace.GetConn != nil { + // trace.GetConn(cm.addr()) + //} + + w := &wantConn{ + cm: cm, + key: cm.key(), + //ctx: ctx, + ready: make(chan struct{}, 1), + beforeDial: testHookPrePendingDial, + afterDial: testHookPostPendingDial, + } + defer func() { + if err != nil { + w.cancel(t, err) + } + }() + + // TODO(spongehah) ConnPool + //// Queue for idle connection. + //if delivered := t.queueForIdleConn(w); delivered { + // pc := w.pc + // // Trace only for HTTP/1. + // // HTTP/2 calls trace.GotConn itself. + // if pc.alt == nil && trace != nil && trace.GotConn != nil { + // trace.GotConn(pc.gotIdleConnTrace(pc.idleAt)) + // } + // // set request canceler to some non-nil function so we + // // can detect whether it was cleared between now and when + // // we enter roundTrip + // t.setReqCanceler(treq.cancelKey, func(error) {}) + // return pc, nil + //} + + cancelc := make(chan error, 1) + t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err }) + + // Queue for permission to dial. + t.queueForDial(w) + + // Wait for completion or cancellation. + select { + case <-w.ready: + // Trace success but only for HTTP/1. + // HTTP/2 calls trace.GotConn itself. + //if w.pc != nil && w.pc.alt == nil && trace != nil && trace.GotConn != nil { + // trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()}) + //} + if w.err != nil { + // If the request has been canceled, that's probably + // what caused w.err; if so, prefer to return the + // cancellation error (see golang.org/issue/16049). + select { + //case <-req.Cancel: + // return nil, errRequestCanceledConn + //case <-req.Context().Done(): + // return nil, req.Context().Err() + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn + } + return nil, err + default: + // return below + } + } + return w.pc, w.err + //case <-req.Cancel: + // return nil, errRequestCanceledConn + //case <-req.Context().Done(): + // return nil, req.Context().Err() + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn + } return nil, err } - return resp, nil } -func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { - host := req.URL.Hostname() - port := req.URL.Port() +// queueForDial queues w to wait for permission to begin dialing. +// Once w receives permission to dial, it will do so in a separate goroutine. +func (t *Transport) queueForDial(w *wantConn) { + w.beforeDial() + + go t.dialConnFor(w) + // TODO(spongehah) MaxConnsPerHost + //if t.MaxConnsPerHost <= 0 { + // go t.dialConnFor(w) + // return + //} + + //t.connsPerHostMu.Lock() + //defer t.connsPerHostMu.Unlock() + // + //if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost { + // if t.connsPerHost == nil { + // t.connsPerHost = make(map[connectMethodKey]int) + // } + // t.connsPerHost[w.key] = n + 1 + // go t.dialConnFor(w) + // return + //} + // + //if t.connsPerHostWait == nil { + // t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue) + //} + //q := t.connsPerHostWait[w.key] + //q.cleanFront() + //q.pushBack(w) + //t.connsPerHostWait[w.key] = q +} + +// dialConnFor dials on behalf of w and delivers the result to w. +// dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()]. +// If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. +func (t *Transport) dialConnFor(w *wantConn) { + defer w.afterDial() + + pc, err := t.dialConn(w.ctx, w.cm) + w.tryDeliver(pc, err) + // TODO(spongehah) ConnPool + //delivered := w.tryDeliver(pc, err) + //if err == nil && (!delivered || pc.alt != nil) { + // // pconn was not passed to w, + // // or it is HTTP/2 and can be shared. + // // Add to the idle connection pool. + // t.putOrCloseIdleConn(pc) + //} + //if err != nil { + // t.decConnsPerHost(w.key) + //} +} + +func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { + pconn = &persistConn{ + t: t, + reqch: make(chan requestAndChan, 1), + cancelch: make(chan freeChan, 1), + timeoutch: make(chan struct{}, 1), + //writech: make(chan writeRequest, 1), + //closech: make(chan struct{}), + } + + // TODO(spongehah) Proxy dialConn + + treq := cm.treq + host := treq.URL.Hostname() + port := treq.URL.Port() if port == "" { // Hyper only supports http port = defaultHTTPPort } loop := libuv.DefaultLoop() - //conn := (*ConnData)(c.Calloc(1, unsafe.Sizeof(ConnData{}))) conn := new(connData) + pconn.conn = conn if conn == nil { return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") } // If timeout is set, start the timer - timeoutch := make(chan struct{}, 1) - if req.timeout > 0 { + if treq.timeout > 0 { libuv.InitTimer(loop, &conn.TimeoutTimer) ct := &connAndTimeoutChan{ conn: conn, - timeoutch: timeoutch, + timeoutch: pconn.timeoutch, } (*libuv.Handle)(c.Pointer(&conn.TimeoutTimer)).SetData(c.Pointer(ct)) - conn.TimeoutTimer.Start(onTimeout, uint64(req.timeout.Milliseconds()), 0) + conn.TimeoutTimer.Start(onTimeout, uint64(treq.timeout.Milliseconds()), 0) } libuv.InitTcp(loop, &conn.TcpHandle) - //conn.TcpHandle.Data = c.Pointer(conn) (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) var hints net.AddrInfo @@ -145,26 +442,16 @@ func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { var res *net.AddrInfo status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) if status != 0 { - close(timeoutch) + close(pconn.timeoutch) return nil, fmt.Errorf("getaddrinfo error\n") } - //conn.ConnectReq.Data = c.Pointer(conn) (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, onConnect) if status != 0 { - close(timeoutch) + close(pconn.timeoutch) return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } - pconn = &persistConn{ - conn: conn, - t: t, - reqch: make(chan requestAndChan, 1), - cancelch: make(chan freeChan, 1), - timeoutch: timeoutch, - //writech: make(chan writeRequest, 1), - //closech: make(chan struct{}), - } net.Freeaddrinfo(res) @@ -174,18 +461,19 @@ func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { return pconn, nil } -func (pc *persistConn) roundTrip(req *Request) (*Response, error) { +func (pc *persistConn) roundTrip(req *transportRequest) (*Response, error) { + testHookEnterRoundTrip() resc := make(chan responseAndError, 1) pc.reqch <- requestAndChan{ - req: req, + req: req.Request, ch: resc, } // Determine whether timeout has occurred if pc.conn.IsCompleted == 1 { rc := <-pc.reqch // blocking // Free the resources - FreeResources(nil, nil, nil, nil, pc, rc) + freeResources(nil, nil, nil, nil, pc, rc) return nil, fmt.Errorf("request timeout\n") } select { @@ -203,7 +491,6 @@ func (pc *persistConn) roundTrip(req *Request) (*Response, error) { freech: freech, } <-freech - close(freech) return nil, fmt.Errorf("request timeout\n") } } @@ -237,9 +524,9 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { select { case fc := <-pc.cancelch: // Free the resources - FreeResources(nil, respBody, bodyWriter, exec, pc, rc) + freeResources(nil, respBody, bodyWriter, exec, pc, rc) alive = false - fc.freech <- struct{}{} + close(fc.freech) return default: task := exec.Poll() @@ -253,7 +540,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if err != nil { rc.ch <- responseAndError{err: err} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } @@ -265,7 +552,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if err != nil { rc.ch <- responseAndError{err: err} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } @@ -276,7 +563,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if sendRes != hyper.OK { rc.ch <- responseAndError{err: fmt.Errorf("failed to send request")} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } @@ -287,7 +574,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if err != nil { rc.ch <- responseAndError{err: err} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } @@ -299,7 +586,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if err != nil { rc.ch <- responseAndError{err: err} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } @@ -327,7 +614,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if err != nil { rc.ch <- responseAndError{err: err} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } @@ -338,14 +625,14 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if bodyWriter == nil { rc.ch <- responseAndError{err: fmt.Errorf("ResponseBodyWriter is nil")} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } _, err := bodyWriter.Write(bytes) // blocking if err != nil { rc.ch <- responseAndError{err: err} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } buf.Free() @@ -363,12 +650,12 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { c.Printf(c.Str("unexpected task type\n")) rc.ch <- responseAndError{err: fmt.Errorf("unexpected task type\n")} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) alive = false case notSet: @@ -544,7 +831,7 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { case sending: if task.Type() == hyper.TaskError { c.Printf(c.Str("handshake task error!\n")) - return Fail((*hyper.Error)(task.Value())) + return fail((*hyper.Error)(task.Value())) } if task.Type() != hyper.TaskClientConn { return fmt.Errorf("unexpected task type\n") @@ -553,7 +840,7 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { case receiveResp: if task.Type() == hyper.TaskError { c.Printf(c.Str("send task error!\n")) - return Fail((*hyper.Error)(task.Value())) + return fail((*hyper.Error)(task.Value())) } if task.Type() != hyper.TaskResponse { c.Printf(c.Str("unexpected task type\n")) @@ -563,7 +850,7 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { case receiveRespBody: if task.Type() == hyper.TaskError { c.Printf(c.Str("body error!\n")) - return Fail((*hyper.Error)(task.Value())) + return fail((*hyper.Error)(task.Value())) } return nil case notSet: @@ -571,8 +858,8 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { return fmt.Errorf("unexpected TaskId\n") } -// Fail prints the error details and panics -func Fail(err *hyper.Error) error { +// fail prints the error details and panics +func fail(err *hyper.Error) error { if err != nil { c.Printf(c.Str("error code: %d\n"), err.Code()) // grab the error details @@ -588,8 +875,8 @@ func Fail(err *hyper.Error) error { return nil } -// FreeResources frees the resources -func FreeResources(task *hyper.Task, respBody *hyper.Body, bodyWriter *io.PipeWriter, exec *hyper.Executor, pc *persistConn, rc requestAndChan) { +// freeResources frees the resources +func freeResources(task *hyper.Task, respBody *hyper.Body, bodyWriter *io.PipeWriter, exec *hyper.Executor, pc *persistConn, rc requestAndChan) { // Cleaning up before exiting if task != nil { task.Free() @@ -604,13 +891,13 @@ func FreeResources(task *hyper.Task, respBody *hyper.Body, bodyWriter *io.PipeWr exec.Free() } (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) - FreeConnData(pc.conn) + freeConnData(pc.conn) - CloseChannels(rc, pc) + closeChannels(rc, pc) } -// CloseChannels closes the channels -func CloseChannels(rc requestAndChan, pc *persistConn) { +// closeChannels closes the channels +func closeChannels(rc requestAndChan, pc *persistConn) { // Closing the channel close(rc.ch) close(pc.reqch) @@ -618,8 +905,8 @@ func CloseChannels(rc requestAndChan, pc *persistConn) { close(pc.cancelch) } -// FreeConnData frees the connection data -func FreeConnData(conn *connData) { +// freeConnData frees the connection data +func freeConnData(conn *connData) { if conn.ReadWaker != nil { conn.ReadWaker.Free() conn.ReadWaker = nil @@ -643,9 +930,23 @@ func (e *httpError) Error() string { return e.err } func (e *httpError) Timeout() bool { return e.timeout } func (e *httpError) Temporary() bool { return true } +func nop() {} + +// ErrSkipAltProtocol is a sentinel error value defined by Transport.RegisterProtocol. +var ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol") + +var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss") + var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} -func nop() {} +// errRequestCanceled is set to be identical to the one from h2 to facilitate +// testing. +var errRequestCanceled = http2errRequestCanceled + +// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not +// exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. +var http2errRequestCanceled = errors.New("net/http: request canceled") +var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify? /*// alternateRoundTripper returns the alternate RoundTripper to use // for this request if the Request's URL scheme requires one, @@ -678,4 +979,223 @@ func idnaASCIIFromURL(url *url.URL) string { addr = v } return addr -} \ No newline at end of file +} + +type readTrackingBody struct { + io.ReadCloser + didRead bool + didClose bool +} + +func (r *readTrackingBody) Read(data []byte) (int, error) { + r.didRead = true + return r.ReadCloser.Read(data) +} + +func (r *readTrackingBody) Close() error { + r.didClose = true + return r.ReadCloser.Close() +} + +// testHooks. Always non-nil. +var ( + testHookEnterRoundTrip = nop + testHookWaitResLoop = nop + testHookRoundTripRetried = nop + testHookPrePendingDial = nop + testHookPostPendingDial = nop + + testHookMu sync.Locker = fakeLocker{} // guards following + testHookReadLoopBeforeNextRead = nop +) + +// fakeLocker is a sync.Locker which does nothing. It's used to guard +// test-only fields when not under test, to avoid runtime atomic +// overhead. +type fakeLocker struct{} + +func (fakeLocker) Lock() {} +func (fakeLocker) Unlock() {} + +// setupRewindBody returns a new request with a custom body wrapper +// that can report whether the body needs rewinding. +// This lets rewindBody avoid an error result when the request +// does not have GetBody but the body hasn't been read at all yet. +func setupRewindBody(req *Request) *Request { + if req.Body == nil || req.Body == NoBody { + return req + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: req.Body} + return &newReq +} + +// rewindBody returns a new request with the body rewound. +// It returns req unmodified if the body does not need rewinding. +// rewindBody takes care of closing req.Body when appropriate +// (in all cases except when rewindBody returns req unmodified). +func rewindBody(req *Request) (rewound *Request, err error) { + if req.Body == nil || req.Body == NoBody || (!req.Body.(*readTrackingBody).didRead && !req.Body.(*readTrackingBody).didClose) { + return req, nil // nothing to rewind + } + if !req.Body.(*readTrackingBody).didClose { + req.closeBody() + } + if req.GetBody == nil { + return nil, errCannotRewind + } + body, err := req.GetBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: body} + return &newReq, nil +} + +func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + if t.reqCanceler == nil { + t.reqCanceler = make(map[cancelKey]func(error)) + } + if fn != nil { + t.reqCanceler[key] = fn + } else { + delete(t.reqCanceler, key) + } +} + +// connectMethod is the map key (in its String form) for keeping persistent +// TCP connections alive for subsequent HTTP requests. +// +// A connect method may be of the following types: +// +// connectMethod.key().String() Description +// ------------------------------ ------------------------- +// |http|foo.com http directly to server, no proxy +// |https|foo.com https directly to server, no proxy +// |https,h1|foo.com https directly to server w/o HTTP/2, no proxy +// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com +// http://proxy.com|http http to proxy, http to anywhere after that +// socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com +// socks5://proxy.com|https|foo.com socks5 to proxy, then https to foo.com +// https://proxy.com|https|foo.com https to proxy, then CONNECT to foo.com +// https://proxy.com|http https to proxy, http to anywhere after that +type connectMethod struct { + _ incomparable + proxyURL *url.URL // nil for no proxy, else full proxy URL + targetScheme string // "http" or "https" + // If proxyURL specifies an http or https proxy, and targetScheme is http (not https), + // then targetAddr is not included in the connect method key, because the socket can + // be reused for different targetAddr values. + targetAddr string + treq *transportRequest // optional + onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 +} + +func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + cm.targetScheme = treq.URL.Scheme + // TODO(spongehah) canonicalAddr & Proxy + //cm.targetAddr = canonicalAddr(treq.URL) + //if t.Proxy != nil { + // cm.proxyURL, err = t.Proxy(treq.Request) + //} + cm.treq = treq + cm.onlyH1 = treq.requiresHTTP1() + return cm, err +} + +// connectMethodKey is the map key version of connectMethod, with a +// stringified proxy URL (or the empty string) instead of a pointer to +// a URL. +type connectMethodKey struct { + proxy, scheme, addr string + onlyH1 bool +} + +// A wantConn records state about a wanted connection +// (that is, an active call to getConn). +// The conn may be gotten by dialing or by finding an idle connection, +// or a cancellation may make the conn no longer wanted. +// These three options are racing against each other and use +// wantConn to coordinate and agree about the winning outcome. +type wantConn struct { + cm connectMethod + key connectMethodKey // cm.key() + ctx context.Context // context for dial + ready chan struct{} // closed when pc, err pair is delivered + + // hooks for testing to know when dials are done + // beforeDial is called in the getConn goroutine when the dial is queued. + // afterDial is called when the dial is completed or canceled. + beforeDial func() + afterDial func() + + mu sync.Mutex // protects pc, err, close(ready) + pc *persistConn + err error +} + +// waiting reports whether w is still waiting for an answer (connection or error). +func (w *wantConn) waiting() bool { + select { + case <-w.ready: + return false + default: + return true + } +} + +// tryDeliver attempts to deliver pc, err to w and reports whether it succeeded. +func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { + w.mu.Lock() + defer w.mu.Unlock() + + if w.pc != nil || w.err != nil { + return false + } + + w.pc = pc + w.err = err + if w.pc == nil && w.err == nil { + panic("net/http: internal error: misuse of tryDeliver") + } + close(w.ready) + return true +} + +// cancel marks w as no longer wanting a result (for example, due to cancellation). +// If a connection has been delivered already, cancel returns it with t.putOrCloseIdleConn. +func (w *wantConn) cancel(t *Transport, err error) { + w.mu.Lock() + if w.pc == nil && w.err == nil { + close(w.ready) // catch misbehavior in future delivery + } + //pc := w.pc + w.pc = nil + w.err = err + w.mu.Unlock() + + // TODO(spongehah) ConnPool + //if pc != nil { + // t.putOrCloseIdleConn(pc) + //} +} + +func (cm *connectMethod) key() connectMethodKey { + proxyStr := "" + targetAddr := cm.targetAddr + if cm.proxyURL != nil { + proxyStr = cm.proxyURL.String() + if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" { + targetAddr = "" + } + } + return connectMethodKey{ + proxy: proxyStr, + scheme: cm.targetScheme, + addr: targetAddr, + onlyH1: cm.onlyH1, + } +} diff --git a/x/http/util.go b/x/http/util.go index 674f481..e5d2d03 100644 --- a/x/http/util.go +++ b/x/http/util.go @@ -94,6 +94,93 @@ func IsTokenRune(r rune) bool { // httpguts.IsTokenRune return i < len(isTokenTable) && isTokenTable[i] } +// ValidHeaderFieldName reports whether v is a valid HTTP/1.x header name. +// HTTP/2 imposes the additional restriction that uppercase ASCII +// letters are not allowed. +// +// RFC 7230 says: +// +// header-field = field-name ":" OWS field-value OWS +// field-name = token +// token = 1*tchar +// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / +// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA +func ValidHeaderFieldName(v string) bool { // httpguts.ValidHeaderFieldName + if len(v) == 0 { + return false + } + for i := 0; i < len(v); i++ { + if !isTokenTable[v[i]] { + return false + } + } + return true +} + +// ValidHeaderFieldValue reports whether v is a valid "field-value" according to +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 : +// +// message-header = field-name ":" [ field-value ] +// field-value = *( field-content | LWS ) +// field-content = +// +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 : +// +// TEXT = +// LWS = [CRLF] 1*( SP | HT ) +// CTL = +// +// RFC 7230 says: +// +// field-value = *( field-content / obs-fold ) +// obj-fold = N/A to http2, and deprecated +// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +// field-vchar = VCHAR / obs-text +// obs-text = %x80-FF +// VCHAR = "any visible [USASCII] character" +// +// http2 further says: "Similarly, HTTP/2 allows header field values +// that are not valid. While most of the values that can be encoded +// will not alter header field parsing, carriage return (CR, ASCII +// 0xd), line feed (LF, ASCII 0xa), and the zero character (NUL, ASCII +// 0x0) might be exploited by an attacker if they are translated +// verbatim. Any request or response that contains a character not +// permitted in a header field value MUST be treated as malformed +// (Section 8.1.2.6). Valid characters are defined by the +// field-content ABNF rule in Section 3.2 of [RFC7230]." +// +// This function does not (yet?) properly handle the rejection of +// strings that begin or end with SP or HTAB. +func ValidHeaderFieldValue(v string) bool { // httpguts.ValidHeaderFieldValue + for i := 0; i < len(v); i++ { + b := v[i] + if isCTL(b) && !isLWS(b) { + return false + } + } + return true +} + +// isLWS reports whether b is linear white space, according +// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 +// +// LWS = [CRLF] 1*( SP | HT ) +func isLWS(b byte) bool { return b == ' ' || b == '\t' } // httpguts.isLWS + +// isCTL reports whether b is a control byte, according +// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 +// +// CTL = +func isCTL(b byte) bool { // httpguts.isCTL + const del = 0x7f // a CTL + return b < ' ' || b == del +} + // IsPrint returns whether s is ASCII and printable according to // https://tools.ietf.org/html/rfc20#section-4.2. func IsPrint(s string) bool { // ascii.IsPrint From d644d432ffe6ef80dffa221179f41aec73fc1ab3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Tue, 20 Aug 2024 10:23:08 +0800 Subject: [PATCH 12/21] refactor(x/http/client): Move the file directory --- x/http/_demo/upload/upload.go | 24 ------------- x/{ => net}/http/_demo/get/get.go | 2 +- x/{ => net}/http/_demo/headers/headers.go | 2 +- x/{ => net}/http/_demo/post/post.go | 2 +- x/{ => net}/http/_demo/postform/postform.go | 2 +- x/{ => net}/http/_demo/redirect/redirect.go | 2 +- .../http/_demo/server/redirectServer.go | 0 x/{ => net}/http/_demo/timeout/timeout.go | 2 +- x/{ => net}/http/_demo/upload/example.txt | 0 x/net/http/_demo/upload/upload.go | 35 +++++++++++++++++++ x/{ => net}/http/client.go | 0 x/{ => net}/http/clone.go | 0 x/{ => net}/http/cookie.go | 0 x/{ => net}/http/header.go | 0 x/{ => net}/http/http.go | 0 x/{ => net}/http/jar.go | 0 x/{ => net}/http/request.go | 0 x/{ => net}/http/response.go | 0 x/{ => net}/http/transfer.go | 0 x/{ => net}/http/transport.go | 0 x/{ => net}/http/util.go | 0 21 files changed, 41 insertions(+), 30 deletions(-) delete mode 100644 x/http/_demo/upload/upload.go rename x/{ => net}/http/_demo/get/get.go (89%) rename x/{ => net}/http/_demo/headers/headers.go (96%) rename x/{ => net}/http/_demo/post/post.go (91%) rename x/{ => net}/http/_demo/postform/postform.go (92%) rename x/{ => net}/http/_demo/redirect/redirect.go (90%) rename x/{ => net}/http/_demo/server/redirectServer.go (100%) rename x/{ => net}/http/_demo/timeout/timeout.go (92%) rename x/{ => net}/http/_demo/upload/example.txt (100%) create mode 100644 x/net/http/_demo/upload/upload.go rename x/{ => net}/http/client.go (100%) rename x/{ => net}/http/clone.go (100%) rename x/{ => net}/http/cookie.go (100%) rename x/{ => net}/http/header.go (100%) rename x/{ => net}/http/http.go (100%) rename x/{ => net}/http/jar.go (100%) rename x/{ => net}/http/request.go (100%) rename x/{ => net}/http/response.go (100%) rename x/{ => net}/http/transfer.go (100%) rename x/{ => net}/http/transport.go (100%) rename x/{ => net}/http/util.go (100%) diff --git a/x/http/_demo/upload/upload.go b/x/http/_demo/upload/upload.go deleted file mode 100644 index c6bb391..0000000 --- a/x/http/_demo/upload/upload.go +++ /dev/null @@ -1,24 +0,0 @@ -package main - -import ( - "fmt" - "io" - - "github.com/goplus/llgoexamples/x/http" -) - -func main() { - resp, err := http.Post("http://httpbin.org/post", "", nil) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(resp.Status) - body, err := io.ReadAll(resp.Body) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(string(body)) - defer resp.Body.Close() -} diff --git a/x/http/_demo/get/get.go b/x/net/http/_demo/get/get.go similarity index 89% rename from x/http/_demo/get/get.go rename to x/net/http/_demo/get/get.go index bff1bd1..79c18ba 100644 --- a/x/http/_demo/get/get.go +++ b/x/net/http/_demo/get/get.go @@ -4,7 +4,7 @@ import ( "fmt" "io" - "github.com/goplus/llgoexamples/x/http" + "github.com/goplus/llgoexamples/x/net/http" ) func main() { diff --git a/x/http/_demo/headers/headers.go b/x/net/http/_demo/headers/headers.go similarity index 96% rename from x/http/_demo/headers/headers.go rename to x/net/http/_demo/headers/headers.go index 2672a66..71d42b7 100644 --- a/x/http/_demo/headers/headers.go +++ b/x/net/http/_demo/headers/headers.go @@ -4,7 +4,7 @@ import ( "fmt" "io" - "github.com/goplus/llgoexamples/x/http" + "github.com/goplus/llgoexamples/x/net/http" ) func main() { diff --git a/x/http/_demo/post/post.go b/x/net/http/_demo/post/post.go similarity index 91% rename from x/http/_demo/post/post.go rename to x/net/http/_demo/post/post.go index 4958a8e..f169dfc 100644 --- a/x/http/_demo/post/post.go +++ b/x/net/http/_demo/post/post.go @@ -5,7 +5,7 @@ import ( "fmt" "io" - "github.com/goplus/llgoexamples/x/http" + "github.com/goplus/llgoexamples/x/net/http" ) func main() { diff --git a/x/http/_demo/postform/postform.go b/x/net/http/_demo/postform/postform.go similarity index 92% rename from x/http/_demo/postform/postform.go rename to x/net/http/_demo/postform/postform.go index 5315ca9..1636786 100644 --- a/x/http/_demo/postform/postform.go +++ b/x/net/http/_demo/postform/postform.go @@ -5,7 +5,7 @@ import ( "io" "net/url" - "github.com/goplus/llgo/x/http" + "github.com/goplus/llgo/x/net/http" ) func main() { diff --git a/x/http/_demo/redirect/redirect.go b/x/net/http/_demo/redirect/redirect.go similarity index 90% rename from x/http/_demo/redirect/redirect.go rename to x/net/http/_demo/redirect/redirect.go index 48465b7..e4fdb92 100644 --- a/x/http/_demo/redirect/redirect.go +++ b/x/net/http/_demo/redirect/redirect.go @@ -4,7 +4,7 @@ import ( "fmt" "io" - "github.com/goplus/llgoexamples/x/http" + "github.com/goplus/llgoexamples/x/net/http" ) func main() { diff --git a/x/http/_demo/server/redirectServer.go b/x/net/http/_demo/server/redirectServer.go similarity index 100% rename from x/http/_demo/server/redirectServer.go rename to x/net/http/_demo/server/redirectServer.go diff --git a/x/http/_demo/timeout/timeout.go b/x/net/http/_demo/timeout/timeout.go similarity index 92% rename from x/http/_demo/timeout/timeout.go rename to x/net/http/_demo/timeout/timeout.go index 42f8bf8..ddb2d25 100644 --- a/x/http/_demo/timeout/timeout.go +++ b/x/net/http/_demo/timeout/timeout.go @@ -5,7 +5,7 @@ import ( "io" "time" - "github.com/goplus/llgoexamples/x/http" + "github.com/goplus/llgoexamples/x/net/http" ) func main() { diff --git a/x/http/_demo/upload/example.txt b/x/net/http/_demo/upload/example.txt similarity index 100% rename from x/http/_demo/upload/example.txt rename to x/net/http/_demo/upload/example.txt diff --git a/x/net/http/_demo/upload/upload.go b/x/net/http/_demo/upload/upload.go new file mode 100644 index 0000000..3dab514 --- /dev/null +++ b/x/net/http/_demo/upload/upload.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "io" + "os" + + "github.com/goplus/llgoexamples/x/net/http" +) + +func main() { + url := "http://httpbin.org/post" + filePath := "/Users/spongehah/go/src/llgo/x/http/_demo/upload/example.txt" // Replace with your file path + + file, err := os.Open(filePath) + if err != nil { + fmt.Println("Error opening file:", err) + return + } + defer file.Close() + + resp, err := http.Post(url, "application/octet-stream", file) + if err != nil { + fmt.Println(err) + return + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(respBody)) +} diff --git a/x/http/client.go b/x/net/http/client.go similarity index 100% rename from x/http/client.go rename to x/net/http/client.go diff --git a/x/http/clone.go b/x/net/http/clone.go similarity index 100% rename from x/http/clone.go rename to x/net/http/clone.go diff --git a/x/http/cookie.go b/x/net/http/cookie.go similarity index 100% rename from x/http/cookie.go rename to x/net/http/cookie.go diff --git a/x/http/header.go b/x/net/http/header.go similarity index 100% rename from x/http/header.go rename to x/net/http/header.go diff --git a/x/http/http.go b/x/net/http/http.go similarity index 100% rename from x/http/http.go rename to x/net/http/http.go diff --git a/x/http/jar.go b/x/net/http/jar.go similarity index 100% rename from x/http/jar.go rename to x/net/http/jar.go diff --git a/x/http/request.go b/x/net/http/request.go similarity index 100% rename from x/http/request.go rename to x/net/http/request.go diff --git a/x/http/response.go b/x/net/http/response.go similarity index 100% rename from x/http/response.go rename to x/net/http/response.go diff --git a/x/http/transfer.go b/x/net/http/transfer.go similarity index 100% rename from x/http/transfer.go rename to x/net/http/transfer.go diff --git a/x/http/transport.go b/x/net/http/transport.go similarity index 100% rename from x/http/transport.go rename to x/net/http/transport.go diff --git a/x/http/util.go b/x/net/http/util.go similarity index 100% rename from x/http/util.go rename to x/net/http/util.go From 2944a9d8dc73633ea3afba7362969b75c7b3c651 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Tue, 20 Aug 2024 18:14:05 +0800 Subject: [PATCH 13/21] WIP(x/http/client): http proxy & 100-continue & KeepAlive & gzip & some code improvement --- .../_demo/maxConnsPerHost/maxConnsPerHost.go | 32 + x/net/http/_demo/postform/postform.go | 2 +- x/net/http/_demo/upload/upload.go | 12 +- x/net/http/request.go | 42 +- x/net/http/response.go | 7 + x/net/http/transfer.go | 25 - x/net/http/transport.go | 931 ++++++++++++++---- x/net/http/util.go | 25 + x/net/ipsock.go | 24 + 9 files changed, 894 insertions(+), 206 deletions(-) create mode 100644 x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go create mode 100644 x/net/ipsock.go diff --git a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go new file mode 100644 index 0000000..63cedbc --- /dev/null +++ b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go @@ -0,0 +1,32 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgoexamples/x/net/http" +) + +func main() { + client := &http.Client{ + //Transport: &http.Transport{ + // MaxConnsPerHost: 2, + //}, + } + req, err := http.NewRequest("GET", "https://www.baidu.com", nil) + resp, err := client.Do(req) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) + fmt.Println(resp.Proto) + resp.PrintHeaders() + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + defer resp.Body.Close() +} diff --git a/x/net/http/_demo/postform/postform.go b/x/net/http/_demo/postform/postform.go index 1636786..eae4d6e 100644 --- a/x/net/http/_demo/postform/postform.go +++ b/x/net/http/_demo/postform/postform.go @@ -5,7 +5,7 @@ import ( "io" "net/url" - "github.com/goplus/llgo/x/net/http" + "github.com/goplus/llgoexamples/x/net/http" ) func main() { diff --git a/x/net/http/_demo/upload/upload.go b/x/net/http/_demo/upload/upload.go index 3dab514..86b57e9 100644 --- a/x/net/http/_demo/upload/upload.go +++ b/x/net/http/_demo/upload/upload.go @@ -10,7 +10,7 @@ import ( func main() { url := "http://httpbin.org/post" - filePath := "/Users/spongehah/go/src/llgo/x/http/_demo/upload/example.txt" // Replace with your file path + filePath := "/Users/spongehah/go/src/llgo/x/net/http/_demo/upload/example.txt" // Replace with your file path file, err := os.Open(filePath) if err != nil { @@ -19,7 +19,15 @@ func main() { } defer file.Close() - resp, err := http.Post(url, "application/octet-stream", file) + client := &http.Client{} + req, err := http.NewRequest("POST", url, file) + if err != nil { + fmt.Println(err) + return + } + req.Header.Set("expect", "100-continue") + resp, err := client.Do(req) + if err != nil { fmt.Println(err) return diff --git a/x/net/http/request.go b/x/net/http/request.go index 38b74bb..b44b00b 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -3,6 +3,7 @@ package http import ( "bytes" "context" + "errors" "fmt" "io" "net/textproto" @@ -73,6 +74,16 @@ func NewRequest(method, url string, body io.Reader) (*Request, error) { // redirects can replay the body), and Body is set to NoBody if the // ContentLength is 0. func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.Reader) (*Request, error) { + // TODO(spongehah) Hyper only supports http + isHttpPrefix := strings.HasPrefix(urlStr, "http://") + isHttpsPrefix := strings.HasPrefix(urlStr, "https://") + if !isHttpPrefix && !isHttpsPrefix { + urlStr = "http://" + urlStr + } + if isHttpsPrefix { + urlStr = "http://" + strings.TrimPrefix(urlStr, "https://") + } + if method == "" { // We document that "" means "GET" for Request.Method, and people have // relied on that from NewRequest, so keep that working. @@ -82,9 +93,9 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R if !validMethod(method) { return nil, fmt.Errorf("net/http: invalid method %q", method) } - //if ctx == nil { - // return nil, errors.New("net/http: nil Context") - //} + if ctx == nil { + return nil, errors.New("net/http: nil Context") + } u, err := url.Parse(urlStr) if err != nil { return nil, err @@ -241,8 +252,10 @@ func newHyperRequest(req *Request) (*hyper.Request, error) { } if method == "POST" && req.Body != nil { - req.Header.Set("expect", "100-continue") - hyperReq.OnInformational(printInformational, nil) + // 100-continue + if req.ProtoAtLeast(1, 1) && req.Body != nil && req.expectsContinue() { + hyperReq.OnInformational(printInformational, nil) + } hyperReqBody := hyper.NewBody() reqData := &postReq{ @@ -285,6 +298,17 @@ func (req *Request) setHeaders(hyperReq *hyper.Request) error { return nil } +func (r *Request) expectsContinue() bool { + return hasToken(r.Header.get("Expect"), "100-continue") +} + +func (r *Request) wantsClose() bool { + if r.Close { + return true + } + return hasToken(r.Header.get("Connection"), "close") +} + func (r *Request) closeBody() error { if r.Body == nil { return nil @@ -354,6 +378,14 @@ func (r *Request) Cookies() []*Cookie { return readCookies(r.Header, "") } +// ProtoAtLeast reports whether the HTTP protocol used +// in the request is at least major.minor. +func (r *Request) ProtoAtLeast(major, minor int) bool { + return r.ProtoMajor > major || + r.ProtoMajor == major && r.ProtoMinor >= minor +} + + // readCookies parses all "Cookie" values from the header h and // returns the successfully parsed Cookies. // diff --git a/x/net/http/response.go b/x/net/http/response.go index 174d2fc..32b5723 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -82,3 +82,10 @@ func fixPragmaCacheControl(header Header) { func (r *Response) Cookies() []*Cookie { return readSetCookies(r.Header) } + +// isProtocolSwitchHeader reports whether the request or response header +// is for a protocol switch. +func isProtocolSwitchHeader(h Header) bool { + return h.Get("Upgrade") != "" && + HeaderValuesContainsToken(h["Connection"], "Upgrade") +} diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index ac50296..0787270 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -388,31 +388,6 @@ func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { return hasClose } -// HeaderValuesContainsToken reports whether any string in values -// contains the provided token, ASCII case-insensitively. -func HeaderValuesContainsToken(values []string, token string) bool { - for _, v := range values { - if headerValueContainsToken(v, token) { - return true - } - } - return false -} - -// headerValueContainsToken reports whether v (assumed to be a -// 0#element, in the ABNF extension described in RFC 7230 section 7) -// contains token amongst its comma-separated tokens, ASCII -// case-insensitively. -func headerValueContainsToken(v string, token string) bool { - for comma := strings.IndexByte(v, ','); comma != -1; comma = strings.IndexByte(v, ',') { - if tokenEqual(trimOWS(v[:comma]), token) { - return true - } - v = v[comma+1:] - } - return tokenEqual(trimOWS(v), token) -} - // tokenEqual reports whether t1 and t2 are equal, ASCII case-insensitively. func tokenEqual(t1, t2 string) bool { if len(t1) != len(t2) { diff --git a/x/net/http/transport.go b/x/net/http/transport.go index e54b357..e5467e2 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -14,25 +14,36 @@ import ( "github.com/goplus/llgo/c/libuv" "github.com/goplus/llgo/c/net" "github.com/goplus/llgo/c/syscall" + xnet "github.com/goplus/llgoexamples/x/net" "github.com/goplus/llgoexamples/rust/hyper" ) -type connData struct { - TcpHandle libuv.Tcp - ConnectReq libuv.Connect - ReadBuf libuv.Buf - TimeoutTimer libuv.Timer - IsCompleted int - ReadBufFilled uintptr - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker -} - type Transport struct { altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme reqMu sync.Mutex reqCanceler map[cancelKey]func(error) - //Proxy func(*Request) (*url.URL, error) + Proxy func(*Request) (*url.URL, error) + + connsPerHostMu sync.Mutex + connsPerHost map[connectMethodKey]int + connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns + + // DisableKeepAlives, if true, disables HTTP keep-alives and + // will only use the connection to the server for a single + // HTTP request. + // + // This is unrelated to the similarly named TCP keep-alives. + DisableKeepAlives bool + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. + DisableCompression bool // MaxConnsPerHost optionally limits the total number of // connections per host, including connections in the dialing, @@ -42,17 +53,15 @@ type Transport struct { MaxConnsPerHost int } -var DefaultTransport RoundTripper = &Transport{} - -// taskId The unique identifier of the next task polled from the executor -type taskId c.Int - -const ( - notSet taskId = iota - sending - receiveResp - receiveRespBody -) +// DefaultTransport is the default implementation of Transport and is +// used by DefaultClient. It establishes network connections as needed +// and caches them for reuse by subsequent calls. It uses HTTP proxies +// as directed by the environment variables HTTP_PROXY, HTTPS_PROXY +// and NO_PROXY (or the lowercase versions thereof). +var DefaultTransport RoundTripper = &Transport{ + //Proxy: ProxyFromEnvironment, + Proxy: nil, +} const ( defaultHTTPPort = "80" @@ -65,16 +74,34 @@ type persistConn struct { // This is used for HTTP/2 today and future protocols later. // If it's non-nil, the rest of the fields are unused. alt RoundTripper + //br *bufio.Reader // from conn //bw *bufio.Writer // to conn //nwrite int64 // bytes written //writech chan writeRequest // written by roundTrip; read by writeLoop //closech chan struct{} // closed when conn closed - conn *connData - t *Transport - reqch chan requestAndChan // written by roundTrip; read by readLoop + + t *Transport + cacheKey connectMethodKey + conn *connData + nwrite int64 // bytes written + reqch chan requestAndChan // written by roundTrip; read by readLoop + closech chan struct{} // closed when conn closed + writeLoopDone chan struct{} // closed when write loop ends + cancelch chan freeChan timeoutch chan struct{} + + isProxy bool + mu sync.Mutex // guards following fields + numExpectedResponses int + closed error // set non-nil when conn is closed, before closech is closed + canceledErr error // set non-nil if conn is canceled + broken bool // an error has happened on this connection; marked broken so it's not reused. + // mutateHeaderFunc is an optional func to modify extra + // headers on each outbound request before it's written. (the + // original Request given to RoundTrip is not modified) + mutateHeaderFunc func(Header) } // incomparable is a zero-width, non-comparable type. Adding it to a struct @@ -83,9 +110,17 @@ type persistConn struct { type incomparable [0]func() type requestAndChan struct { - _ incomparable - req *Request - ch chan responseAndError // unbuffered; always send in select on callerGone + _ incomparable + req *Request + cancelKey cancelKey + ch chan responseAndError // unbuffered; always send in select on callerGone + + // whether the Transport (as opposed to the user client code) + // added the Accept-Encoding gzip header. If the Transport + // set it, only then do we transparently decode the gzip. + addedGzip bool + + callerGone <-chan struct{} // closed when roundTrip caller has returned } // responseAndError is how the goroutine reading from an HTTP/1 server @@ -127,6 +162,13 @@ type transportRequest struct { err error // first setError value for mapRoundTripError to consider } +func (tr *transportRequest) extraHeaders() Header { + if tr.extra == nil { + tr.extra = make(Header) + } + return tr.extra +} + // useRegisteredProtocol reports whether an alternate protocol (as registered // with Transport.RegisterProtocol) should be respected for this request. func (t *Transport) useRegisteredProtocol(req *Request) bool { @@ -328,6 +370,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi } } return w.pc, w.err + // TODO(spongehah) cancel(t.getConn) //case <-req.Cancel: // return nil, errRequestCanceledConn //case <-req.Context().Done(): @@ -345,32 +388,30 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi func (t *Transport) queueForDial(w *wantConn) { w.beforeDial() - go t.dialConnFor(w) - // TODO(spongehah) MaxConnsPerHost - //if t.MaxConnsPerHost <= 0 { - // go t.dialConnFor(w) - // return - //} + if t.MaxConnsPerHost <= 0 { + go t.dialConnFor(w) + return + } - //t.connsPerHostMu.Lock() - //defer t.connsPerHostMu.Unlock() - // - //if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost { - // if t.connsPerHost == nil { - // t.connsPerHost = make(map[connectMethodKey]int) - // } - // t.connsPerHost[w.key] = n + 1 - // go t.dialConnFor(w) - // return - //} - // - //if t.connsPerHostWait == nil { - // t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue) - //} - //q := t.connsPerHostWait[w.key] - //q.cleanFront() - //q.pushBack(w) - //t.connsPerHostWait[w.key] = q + t.connsPerHostMu.Lock() + defer t.connsPerHostMu.Unlock() + + if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost { + if t.connsPerHost == nil { + t.connsPerHost = make(map[connectMethodKey]int) + } + t.connsPerHost[w.key] = n + 1 + go t.dialConnFor(w) + return + } + + if t.connsPerHostWait == nil { + t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue) + } + q := t.connsPerHostWait[w.key] + q.cleanFront() + q.pushBack(w) + t.connsPerHostWait[w.key] = q } // dialConnFor dials on behalf of w and delivers the result to w. @@ -383,39 +424,176 @@ func (t *Transport) dialConnFor(w *wantConn) { w.tryDeliver(pc, err) // TODO(spongehah) ConnPool //delivered := w.tryDeliver(pc, err) + // Handle undelivered or shareable connections //if err == nil && (!delivered || pc.alt != nil) { // // pconn was not passed to w, // // or it is HTTP/2 and can be shared. // // Add to the idle connection pool. // t.putOrCloseIdleConn(pc) //} + + // TODO(spongehah) decConnsPerHost + // If an error occurs during the dialing process, the connection count for that host is decreased. + // This ensures that the connection count remains accurate even in cases where the dial attempt fails. //if err != nil { // t.decConnsPerHost(w.key) //} } +// decConnsPerHost decrements the per-host connection count for key, +// which may in turn give a different waiting goroutine permission to dial. +//func (t *Transport) decConnsPerHost(key connectMethodKey) { +// if t.MaxConnsPerHost <= 0 { +// return +// } +// +// t.connsPerHostMu.Lock() +// defer t.connsPerHostMu.Unlock() +// n := t.connsPerHost[key] +// if n == 0 { +// // Shouldn't happen, but if it does, the counting is buggy and could +// // easily lead to a silent deadlock, so report the problem loudly. +// panic("net/http: internal error: connCount underflow") +// } +// +// // Can we hand this count to a goroutine still waiting to dial? +// // (Some goroutines on the wait list may have timed out or +// // gotten a connection another way. If they're all gone, +// // we don't want to kick off any spurious dial operations.) +// if q := t.connsPerHostWait[key]; q.len() > 0 { +// done := false +// for q.len() > 0 { +// w := q.popFront() +// if w.waiting() { +// go t.dialConnFor(w) +// done = true +// break +// } +// } +// if q.len() == 0 { +// delete(t.connsPerHostWait, key) +// } else { +// // q is a value (like a slice), so we have to store +// // the updated q back into the map. +// t.connsPerHostWait[key] = q +// } +// if done { +// return +// } +// } +// +// // Otherwise, decrement the recorded count. +// if n--; n == 0 { +// delete(t.connsPerHost, key) +// } else { +// t.connsPerHost[key] = n +// } +//} + func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { pconn = &persistConn{ - t: t, - reqch: make(chan requestAndChan, 1), - cancelch: make(chan freeChan, 1), - timeoutch: make(chan struct{}, 1), + t: t, + cacheKey: cm.key(), + reqch: make(chan requestAndChan, 1), + cancelch: make(chan freeChan, 1), + timeoutch: make(chan struct{}, 1), + closech: make(chan struct{}, 1), + writeLoopDone: make(chan struct{}, 1), //writech: make(chan writeRequest, 1), //closech: make(chan struct{}), } - // TODO(spongehah) Proxy dialConn + //if cm.scheme() == "https" && t.hasCustomTLSDialer() { + // var err error + // pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr()) + // if err != nil { + // return nil, wrapErr(err) + // } + // if tc, ok := pconn.conn.(*tls.Conn); ok { + // // Handshake here, in case DialTLS didn't. TLSNextProto below + // // depends on it for knowing the connection state. + // if trace != nil && trace.TLSHandshakeStart != nil { + // trace.TLSHandshakeStart() + // } + // if err := tc.HandshakeContext(ctx); err != nil { + // go pconn.conn.Close() + // if trace != nil && trace.TLSHandshakeDone != nil { + // trace.TLSHandshakeDone(tls.ConnectionState{}, err) + // } + // return nil, err + // } + // cs := tc.ConnectionState() + // if trace != nil && trace.TLSHandshakeDone != nil { + // trace.TLSHandshakeDone(cs, nil) + // } + // pconn.tlsState = &cs + // } + //} else { + //conn, err := t.dial(ctx, "tcp", cm.addr()) + conn, err := t.dial(ctx, pconn, cm) + if err != nil { + return nil, err + } + pconn.conn = conn + //if cm.scheme() == "https" { + // var firstTLSHost string + // if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { + // return nil, wrapErr(err) + // } + // if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil { + // return nil, wrapErr(err) + // } + //} + //} + + // TODO(spongehah) Proxy(https/sock5) + // Proxy setup. + switch { + case cm.proxyURL == nil: + // Do nothing. Not using a proxy. + // case cm.proxyURL.Scheme == "socks5": + case cm.targetScheme == "http": + pconn.isProxy = true + if pa := cm.proxyAuth(); pa != "" { + pconn.mutateHeaderFunc = func(h Header) { + h.Set("Proxy-Authorization", pa) + } + } + // case cm.targetScheme == "https": + } + //if cm.proxyURL != nil && cm.targetScheme == "https" { + // if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil { + // return nil, err + // } + //} + // + //if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { + // if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { + // alt := next(cm.targetAddr, pconn.conn.(*tls.Conn)) + // if e, ok := alt.(erringRoundTripper); ok { + // // pconn.conn was closed by next (http2configureTransports.upgradeFn). + // return nil, e.RoundTripErr() + // } + // return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: alt}, nil + // } + //} + + if conn.IsCompleted != 1 { + go pconn.readWriteLoop(libuv.DefaultLoop()) + } + return pconn, nil +} + +func (t *Transport) dial(ctx context.Context, pconn *persistConn, cm connectMethod) (*connData, error) { treq := cm.treq host := treq.URL.Hostname() port := treq.URL.Port() if port == "" { - // Hyper only supports http port = defaultHTTPPort } loop := libuv.DefaultLoop() conn := new(connData) - pconn.conn = conn if conn == nil { return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") } @@ -454,44 +632,136 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } net.Freeaddrinfo(res) - - if pconn.conn.IsCompleted != 1 { - go pconn.readWriteLoop(loop) - } - return pconn, nil + return conn, nil } -func (pc *persistConn) roundTrip(req *transportRequest) (*Response, error) { +func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { testHookEnterRoundTrip() - resc := make(chan responseAndError, 1) + if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { + // TODO(spongehah) ConnPool + //pc.t.putOrCloseIdleConn(pc) + return nil, errRequestCanceled + } + pc.mu.Lock() + pc.numExpectedResponses++ + headerFn := pc.mutateHeaderFunc + pc.mu.Unlock() - pc.reqch <- requestAndChan{ - req: req.Request, - ch: resc, + if headerFn != nil { + headerFn(req.extraHeaders()) } - // Determine whether timeout has occurred - if pc.conn.IsCompleted == 1 { - rc := <-pc.reqch // blocking - // Free the resources - freeResources(nil, nil, nil, nil, pc, rc) - return nil, fmt.Errorf("request timeout\n") + + // Ask for a compressed version if the caller didn't set their + // own value for Accept-Encoding. We only attempt to + // uncompress the gzip stream if we were the layer that + // requested it. + requestedGzip := false + if !pc.t.DisableCompression && + req.Header.Get("Accept-Encoding") == "" && + req.Header.Get("Range") == "" && + req.Method != "HEAD" { + // Request gzip only, not deflate. Deflate is ambiguous and + // not as universally supported anyway. + // See: https://zlib.net/zlib_faq.html#faq39 + // + // Note that we don't request this for HEAD requests, + // due to a bug in nginx: + // https://trac.nginx.org/nginx/ticket/358 + // https://golang.org/issue/5522 + // + // We don't request gzip if the request is for a range, since + // auto-decoding a portion of a gzipped document will just fail + // anyway. See https://golang.org/issue/8923 + requestedGzip = true + req.extraHeaders().Set("Accept-Encoding", "gzip") } - select { - case re := <-resc: - if (re.res == nil) == (re.err == nil) { - return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) + + // The 100-continue operation in Hyper is handled in the newHyperRequest function. + + // Keep-Alive + if pc.t.DisableKeepAlives && + !req.wantsClose() && + !isProtocolSwitchHeader(req.Header) { + req.extraHeaders().Set("Connection", "close") + } + + gone := make(chan struct{}) + defer close(gone) + + defer func() { + if err != nil { + pc.t.setReqCanceler(req.cancelKey, nil) } - if re.err != nil { - return nil, re.err + }() + + const debugRoundTrip = false // Debug switch provided for developers + + // Write the request concurrently with waiting for a response, + // in case the server decides to reply before reading our full + // request body. + + // In Hyper, the writeLoop() and readLoop() are combined together --> readWriteLoop(). + startBytesWritten := pc.nwrite + + // Send the request to readWriteLoop(). + resc := make(chan responseAndError, 1) + + pc.reqch <- requestAndChan{ + req: req.Request, + cancelKey: req.cancelKey, + ch: resc, + addedGzip: requestedGzip, + callerGone: gone, + } + + //var respHeaderTimer <-chan time.Time + //cancelChan := req.Request.Cancel + //ctxDoneChan := req.Context().Done() + pcClosed := pc.closech + canceled := false + + for { + testHookWaitResLoop() + + // Determine whether timeout has occurred + if pc.conn.IsCompleted == 1 { + rc := <-pc.reqch // blocking + // Free the resources + freeResources(nil, nil, nil, nil, pc, rc) + return nil, fmt.Errorf("request timeout\n") } - return re.res, nil - case <-pc.timeoutch: - freech := make(chan struct{}, 1) - pc.cancelch <- freeChan{ - freech: freech, + select { + //case err := <-writeErrCh: + case <-pcClosed: + pcClosed = nil + if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { + if debugRoundTrip { + //req.logf("closech recv: %T %#v", pc.closed, pc.closed) + } + return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) + } + //case <-respHeaderTimer: + case re := <-resc: + if (re.res == nil) == (re.err == nil) { + return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) + } + if debugRoundTrip { + //req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err) + } + if re.err != nil { + return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) + } + return re.res, nil + // TODO(spongehah) cancel(pc.roundTrip) + //case <-cancelChan: + case <-pc.timeoutch: + freech := make(chan struct{}, 1) + pc.cancelch <- freeChan{ + freech: freech, + } + <-freech + return nil, fmt.Errorf("request timeout\n") } - <-freech - return nil, fmt.Errorf("request timeout\n") } } @@ -667,6 +937,22 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { //} } +type connData struct { + TcpHandle libuv.Tcp + ConnectReq libuv.Connect + ReadBuf libuv.Buf + TimeoutTimer libuv.Timer + IsCompleted int + ReadBufFilled uintptr + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker +} + +func (conn *connData) Close() error { + freeConnData(conn) + return nil +} + // onConnect is the libuv callback for a successful connection func onConnect(req *libuv.Connect, status c.Int) { //conn := (*ConnData)(req.Data) @@ -819,6 +1105,16 @@ func newIoWithConnReadWrite(connData *connData) *hyper.Io { return hyperIo } +// taskId The unique identifier of the next task polled from the executor +type taskId c.Int + +const ( + notSet taskId = iota + sending + receiveResp + receiveRespBody +) + // setTaskId Set taskId to the task's userdata as a unique identifier func setTaskId(task *hyper.Task, userData taskId) { var data = userData @@ -921,24 +1217,30 @@ func freeConnData(conn *connData) { } } -type httpError struct { - err string - timeout bool -} - -func (e *httpError) Error() string { return e.err } -func (e *httpError) Timeout() bool { return e.timeout } -func (e *httpError) Temporary() bool { return true } +// ---------------------------------------------------------- -func nop() {} +// error values for debugging and testing, not seen by users. +var ( + errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") + errConnBroken = errors.New("http: putIdleConn: connection is in bad state") + errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") + errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") + errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") + errCloseIdleConns = errors.New("http: CloseIdleConnections called") + errReadLoopExiting = errors.New("http: persistConn.readLoop exiting") + errIdleConnTimeout = errors.New("http: idle connection timeout") + + // errServerClosedIdle is not seen by users for idempotent requests, but may be + // seen by a user if the server shuts down an idle connection and sends its FIN + // in flight with already-written POST body bytes from the client. + // See https://github.com/golang/go/issues/19943#issuecomment-355607646 + errServerClosedIdle = errors.New("http: server closed idle connection") +) // ErrSkipAltProtocol is a sentinel error value defined by Transport.RegisterProtocol. var ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol") - var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss") -var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} - // errRequestCanceled is set to be identical to the one from h2 to facilitate // testing. var errRequestCanceled = http2errRequestCanceled @@ -947,6 +1249,67 @@ var errRequestCanceled = http2errRequestCanceled // exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. var http2errRequestCanceled = errors.New("net/http: request canceled") var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify? +// errCallerOwnsConn is an internal sentinel error used when we hand +// off a writable response.Body to the caller. We use this to prevent +// closing a net.Conn that is now owned by the caller. +var errCallerOwnsConn = errors.New("read loop ending; caller owns writable underlying conn") + +type httpError struct { + err string + timeout bool +} + +func (e *httpError) Error() string { return e.err } +func (e *httpError) Timeout() bool { return e.timeout } +func (e *httpError) Temporary() bool { return true } + +// fakeLocker is a sync.Locker which does nothing. It's used to guard +// test-only fields when not under test, to avoid runtime atomic +// overhead. +type fakeLocker struct{} + +func (fakeLocker) Lock() {} +func (fakeLocker) Unlock() {} + +// nothingWrittenError wraps a write errors which ended up writing zero bytes. +type nothingWrittenError struct { + error +} + +func (nwe nothingWrittenError) Unwrap() error { + return nwe.error +} + +// transportReadFromServerError is used by Transport.readLoop when the +// 1 byte peek read fails and we're actually anticipating a response. +// Usually this is just due to the inherent keep-alive shut down race, +// where the server closed the connection at the same time the client +// wrote. The underlying err field is usually io.EOF or some +// ECONNRESET sort of thing which varies by platform. But it might be +// the user's custom net.Conn.Read error too, so we carry it along for +// them to return from Transport.RoundTrip. +type transportReadFromServerError struct { + err error +} + +func (e transportReadFromServerError) Unwrap() error { return e.err } +func (e transportReadFromServerError) Error() string { + return fmt.Sprintf("net/http: Transport failed to read from server: %v", e.err) +} + +func nop() {} + +// testHooks. Always non-nil. +var ( + testHookEnterRoundTrip = nop + testHookWaitResLoop = nop + testHookRoundTripRetried = nop + testHookPrePendingDial = nop + testHookPostPendingDial = nop + + testHookMu sync.Locker = fakeLocker{} // guards following + testHookReadLoopBeforeNextRead = nop +) /*// alternateRoundTripper returns the alternate RoundTripper to use // for this request if the Request's URL scheme requires one, @@ -973,12 +1336,160 @@ func (t *Transport) useRegisteredProtocol(req *Request) bool { } */ -func idnaASCIIFromURL(url *url.URL) string { - addr := url.Hostname() - if v, err := idnaASCII(addr); err == nil { - addr = v +func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + cm.targetScheme = treq.URL.Scheme + cm.targetAddr = canonicalAddr(treq.URL) + if t.Proxy != nil { + cm.proxyURL, err = t.Proxy(treq.Request) } - return addr + cm.treq = treq + cm.onlyH1 = treq.requiresHTTP1() + return cm, err +} + +func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + if t.reqCanceler == nil { + t.reqCanceler = make(map[cancelKey]func(error)) + } + if fn != nil { + t.reqCanceler[key] = fn + } else { + delete(t.reqCanceler, key) + } +} + +// replaceReqCanceler replaces an existing cancel function. If there is no cancel function +// for the request, we don't set the function and return false. +// Since CancelRequest will clear the canceler, we can use the return value to detect if +// the request was canceled since the last setReqCancel call. +func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { + t.reqMu.Lock() + defer t.reqMu.Unlock() + _, ok := t.reqCanceler[key] + if !ok { + return false + } + if fn != nil { + t.reqCanceler[key] = fn + } else { + delete(t.reqCanceler, key) + } + return true +} + +func (pc *persistConn) cancelRequest(err error) { + pc.mu.Lock() + defer pc.mu.Unlock() + pc.canceledErr = err + pc.closeLocked(errRequestCanceled) +} + +// close closes the underlying TCP connection and closes +// the pc.closech channel. +// +// The provided err is only for testing and debugging; in normal +// circumstances it should never be seen by users. +func (pc *persistConn) close(err error) { + pc.mu.Lock() + defer pc.mu.Unlock() + pc.closeLocked(err) +} + +func (pc *persistConn) closeLocked(err error) { + if err == nil { + panic("nil error") + } + pc.broken = true + if pc.closed == nil { + pc.closed = err + // TODO(spongehah) decConnsPerHost + //pc.t.decConnsPerHost(pc.cacheKey) + // Close HTTP/1 (pc.alt == nil) connection. + // HTTP/2 closes its connection itself. + if pc.alt == nil { + if err != errCallerOwnsConn { + pc.conn.Close() + } + close(pc.closech) + } + } + pc.mutateHeaderFunc = nil +} + +// mapRoundTripError returns the appropriate error value for +// persistConn.roundTrip. +// +// The provided err is the first error that (*persistConn).roundTrip +// happened to receive from its select statement. +// +// The startBytesWritten value should be the value of pc.nwrite before the roundTrip +// started writing the request. +func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritten int64, err error) error { + if err == nil { + return nil + } + + // Wait for the writeLoop goroutine to terminate to avoid data + // races on callers who mutate the request on failure. + // + // When resc in pc.roundTrip and hence rc.ch receives a responseAndError + // with a non-nil error it implies that the persistConn is either closed + // or closing. Waiting on pc.writeLoopDone is hence safe as all callers + // close closech which in turn ensures writeLoop returns. + <-pc.writeLoopDone + + // If the request was canceled, that's better than network + // failures that were likely the result of tearing down the + // connection. + if cerr := pc.canceled(); cerr != nil { + return cerr + } + + // See if an error was set explicitly. + req.mu.Lock() + reqErr := req.err + req.mu.Unlock() + if reqErr != nil { + return reqErr + } + + if err == errServerClosedIdle { + // Don't decorate + return err + } + + if _, ok := err.(transportReadFromServerError); ok { + if pc.nwrite == startBytesWritten { + return nothingWrittenError{err} + } + // Don't decorate + return err + } + if pc.isBroken() { + if pc.nwrite == startBytesWritten { + return nothingWrittenError{err} + } + return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %w", err) + } + return err +} + +// canceled returns non-nil if the connection was closed due to +// CancelRequest or due to context cancellation. +func (pc *persistConn) canceled() error { + pc.mu.Lock() + defer pc.mu.Unlock() + return pc.canceledErr +} + +// isBroken reports whether this connection is in a known broken state. +func (pc *persistConn) isBroken() bool { + pc.mu.Lock() + b := pc.closed != nil + pc.mu.Unlock() + return b } type readTrackingBody struct { @@ -997,26 +1508,6 @@ func (r *readTrackingBody) Close() error { return r.ReadCloser.Close() } -// testHooks. Always non-nil. -var ( - testHookEnterRoundTrip = nop - testHookWaitResLoop = nop - testHookRoundTripRetried = nop - testHookPrePendingDial = nop - testHookPostPendingDial = nop - - testHookMu sync.Locker = fakeLocker{} // guards following - testHookReadLoopBeforeNextRead = nop -) - -// fakeLocker is a sync.Locker which does nothing. It's used to guard -// test-only fields when not under test, to avoid runtime atomic -// overhead. -type fakeLocker struct{} - -func (fakeLocker) Lock() {} -func (fakeLocker) Unlock() {} - // setupRewindBody returns a new request with a custom body wrapper // that can report whether the body needs rewinding. // This lets rewindBody avoid an error result when the request @@ -1053,17 +1544,27 @@ func rewindBody(req *Request) (rewound *Request, err error) { return &newReq, nil } -func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { - t.reqMu.Lock() - defer t.reqMu.Unlock() - if t.reqCanceler == nil { - t.reqCanceler = make(map[cancelKey]func(error)) +var portMap = map[string]string{ + "http": "80", + "https": "443", + "socks5": "1080", +} + +func idnaASCIIFromURL(url *url.URL) string { + addr := url.Hostname() + if v, err := idnaASCII(addr); err == nil { + addr = v } - if fn != nil { - t.reqCanceler[key] = fn - } else { - delete(t.reqCanceler, key) + return addr +} + +// canonicalAddr returns url.Host but always with a ":port" suffix. +func canonicalAddr(url *url.URL) string { + port := url.Port() + if port == "" { + port = portMap[url.Scheme] } + return xnet.JoinHostPort(idnaASCIIFromURL(url), port) } // connectMethod is the map key (in its String form) for keeping persistent @@ -1094,16 +1595,51 @@ type connectMethod struct { onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 } -func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { - cm.targetScheme = treq.URL.Scheme - // TODO(spongehah) canonicalAddr & Proxy - //cm.targetAddr = canonicalAddr(treq.URL) - //if t.Proxy != nil { - // cm.proxyURL, err = t.Proxy(treq.Request) - //} - cm.treq = treq - cm.onlyH1 = treq.requiresHTTP1() - return cm, err +func (cm *connectMethod) key() connectMethodKey { + proxyStr := "" + targetAddr := cm.targetAddr + if cm.proxyURL != nil { + proxyStr = cm.proxyURL.String() + if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" { + targetAddr = "" + } + } + return connectMethodKey{ + proxy: proxyStr, + scheme: cm.targetScheme, + addr: targetAddr, + onlyH1: cm.onlyH1, + } +} + +// scheme returns the first hop scheme: http, https, or socks5 +func (cm *connectMethod) scheme() string { + if cm.proxyURL != nil { + return cm.proxyURL.Scheme + } + return cm.targetScheme +} + +// addr returns the first hop "host:port" to which we need to TCP connect. +func (cm *connectMethod) addr() string { + if cm.proxyURL != nil { + return canonicalAddr(cm.proxyURL) + } + return cm.targetAddr +} + +// proxyAuth returns the Proxy-Authorization header to set +// on requests, if applicable. +func (cm *connectMethod) proxyAuth() string { + if cm.proxyURL == nil { + return "" + } + if u := cm.proxyURL.User; u != nil { + username := u.Username() + password, _ := u.Password() + return "Basic " + basicAuth(username, password) + } + return "" } // connectMethodKey is the map key version of connectMethod, with a @@ -1137,6 +1673,24 @@ type wantConn struct { err error } +// cancel marks w as no longer wanting a result (for example, due to cancellation). +// If a connection has been delivered already, cancel returns it with t.putOrCloseIdleConn. +func (w *wantConn) cancel(t *Transport, err error) { + w.mu.Lock() + if w.pc == nil && w.err == nil { + close(w.ready) // catch misbehavior in future delivery + } + //pc := w.pc + w.pc = nil + w.err = err + w.mu.Unlock() + + // TODO(spongehah) ConnPool + //if pc != nil { + // t.putOrCloseIdleConn(pc) + //} +} + // waiting reports whether w is still waiting for an answer (connection or error). func (w *wantConn) waiting() bool { select { @@ -1165,37 +1719,68 @@ func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { return true } -// cancel marks w as no longer wanting a result (for example, due to cancellation). -// If a connection has been delivered already, cancel returns it with t.putOrCloseIdleConn. -func (w *wantConn) cancel(t *Transport, err error) { - w.mu.Lock() - if w.pc == nil && w.err == nil { - close(w.ready) // catch misbehavior in future delivery - } - //pc := w.pc - w.pc = nil - w.err = err - w.mu.Unlock() +// A wantConnQueue is a queue of wantConns. +type wantConnQueue struct { + // This is a queue, not a deque. + // It is split into two stages - head[headPos:] and tail. + // popFront is trivial (headPos++) on the first stage, and + // pushBack is trivial (append) on the second stage. + // If the first stage is empty, popFront can swap the + // first and second stages to remedy the situation. + // + // This two-stage split is analogous to the use of two lists + // in Okasaki's purely functional queue but without the + // overhead of reversing the list when swapping stages. + head []*wantConn + headPos int + tail []*wantConn +} - // TODO(spongehah) ConnPool - //if pc != nil { - // t.putOrCloseIdleConn(pc) - //} +// len returns the number of items in the queue. +func (q *wantConnQueue) len() int { + return len(q.head) - q.headPos + len(q.tail) } -func (cm *connectMethod) key() connectMethodKey { - proxyStr := "" - targetAddr := cm.targetAddr - if cm.proxyURL != nil { - proxyStr = cm.proxyURL.String() - if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" { - targetAddr = "" +// pushBack adds w to the back of the queue. +func (q *wantConnQueue) pushBack(w *wantConn) { + q.tail = append(q.tail, w) +} + +// popFront removes and returns the wantConn at the front of the queue. +func (q *wantConnQueue) popFront() *wantConn { + if q.headPos >= len(q.head) { + if len(q.tail) == 0 { + return nil } + // Pick up tail as new head, clear tail. + q.head, q.headPos, q.tail = q.tail, 0, q.head[:0] } - return connectMethodKey{ - proxy: proxyStr, - scheme: cm.targetScheme, - addr: targetAddr, - onlyH1: cm.onlyH1, + w := q.head[q.headPos] + q.head[q.headPos] = nil + q.headPos++ + return w +} + +// peekFront returns the wantConn at the front of the queue without removing it. +func (q *wantConnQueue) peekFront() *wantConn { + if q.headPos < len(q.head) { + return q.head[q.headPos] + } + if len(q.tail) > 0 { + return q.tail[0] + } + return nil +} + +// cleanFront pops any wantConns that are no longer waiting from the head of the +// queue, reporting whether any were popped. +func (q *wantConnQueue) cleanFront() (cleaned bool) { + for { + w := q.peekFront() + if w == nil || w.waiting() { + return cleaned + } + q.popFront() + cleaned = true } } diff --git a/x/net/http/util.go b/x/net/http/util.go index e5d2d03..f2efb70 100644 --- a/x/net/http/util.go +++ b/x/net/http/util.go @@ -181,6 +181,31 @@ func isCTL(b byte) bool { // httpguts.isCTL return b < ' ' || b == del } +// HeaderValuesContainsToken reports whether any string in values +// contains the provided token, ASCII case-insensitively. +func HeaderValuesContainsToken(values []string, token string) bool { // httpguts.HeaderValuesContainsToken + for _, v := range values { + if headerValueContainsToken(v, token) { + return true + } + } + return false +} + +// headerValueContainsToken reports whether v (assumed to be a +// 0#element, in the ABNF extension described in RFC 7230 section 7) +// contains token amongst its comma-separated tokens, ASCII +// case-insensitively. +func headerValueContainsToken(v string, token string) bool { // httpguts.headerValueContainsToken + for comma := strings.IndexByte(v, ','); comma != -1; comma = strings.IndexByte(v, ',') { + if tokenEqual(trimOWS(v[:comma]), token) { + return true + } + v = v[comma+1:] + } + return tokenEqual(trimOWS(v), token) +} + // IsPrint returns whether s is ASCII and printable according to // https://tools.ietf.org/html/rfc20#section-4.2. func IsPrint(s string) bool { // ascii.IsPrint diff --git a/x/net/ipsock.go b/x/net/ipsock.go new file mode 100644 index 0000000..855a864 --- /dev/null +++ b/x/net/ipsock.go @@ -0,0 +1,24 @@ +package net + +// JoinHostPort combines host and port into a network address of the +// form "host:port". If host contains a colon, as found in literal +// IPv6 addresses, then JoinHostPort returns "[host]:port". +// +// See func Dial for a description of the host and port parameters. +func JoinHostPort(host, port string) string { + // We assume that host is a literal IPv6 address if host has + // colons. + if IndexByteString(host, ':') >= 0 { + return "[" + host + "]:" + port + } + return host + ":" + port +} + +func IndexByteString(s string, c byte) int { + for i := 0; i < len(s); i++ { + if s[i] == c { + return i + } + } + return -1 +} From d05548e7adde869260c1a48fc2432cbdeffec0b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Wed, 21 Aug 2024 18:58:13 +0800 Subject: [PATCH 14/21] WIP(x/http/client): Optimize readWriteLoop and make some code adjustments --- .../_demo/maxConnsPerHost/maxConnsPerHost.go | 6 +- x/net/http/request.go | 6 +- x/net/http/response.go | 26 + x/net/http/transfer.go | 2 +- x/net/http/transport.go | 822 ++++++++++-------- 5 files changed, 504 insertions(+), 358 deletions(-) diff --git a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go index 63cedbc..882bdc1 100644 --- a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go +++ b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go @@ -9,9 +9,9 @@ import ( func main() { client := &http.Client{ - //Transport: &http.Transport{ - // MaxConnsPerHost: 2, - //}, + Transport: &http.Transport{ + MaxConnsPerHost: 2, + }, } req, err := http.NewRequest("GET", "https://www.baidu.com", nil) resp, err := client.Do(req) diff --git a/x/net/http/request.go b/x/net/http/request.go index b44b00b..cb50936 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -385,7 +385,6 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } - // readCookies parses all "Cookie" values from the header h and // returns the successfully parsed Cookies. // @@ -425,6 +424,11 @@ func readCookies(h Header, filter string) []*Cookie { return cookies } +// requestBodyReadError wraps an error from (*Request).write to indicate +// that the error came from a Read call on the Request.Body. +// This error type should not escape the net/http package to users. +type requestBodyReadError struct{ error } + func idnaASCII(v string) (string, error) { // TODO: Consider removing this check after verifying performance is okay. // Right now punycode verification, length checks, context checks, and the diff --git a/x/net/http/response.go b/x/net/http/response.go index 32b5723..d2a7dd5 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "strconv" + "unsafe" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" @@ -63,6 +64,19 @@ func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { headers.Foreach(appendToResponseHeader, c.Pointer(resp)) } +// appendToResponseBody BodyForeachCallback function: Process the response body +func appendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { + writer := (*io.PipeWriter)(userdata) + bufLen := chunk.Len() + bytes := unsafe.Slice(chunk.Bytes(), bufLen) + _, err := writer.Write(bytes) + if err != nil { + fmt.Println("Error writing to response body:", err) + return hyper.IterBreak + } + return hyper.IterContinue +} + // RFC 7234, section 5.4: Should treat // // Pragma: no-cache @@ -89,3 +103,15 @@ func isProtocolSwitchHeader(h Header) bool { return h.Get("Upgrade") != "" && HeaderValuesContainsToken(h["Connection"], "Upgrade") } + +// bodyIsWritable reports whether the Body supports writing. The +// Transport returns Writable bodies for 101 Switching Protocols +// responses. +// The Transport uses this method to determine whether a persistent +// connection is done being managed from its perspective. Once we +// return a writable response body to a user, the net/http package is +// done managing that connection. +func (r *Response) bodyIsWritable() bool { + _, ok := r.Body.(io.Writer) + return ok +} diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index 0787270..823ef7d 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -367,7 +367,7 @@ func bodyAllowedForStatus(status int) bool { return true } -// Determine whether to hang up after sending a request and body, or +// Determine whether to hang up after write a request and body, or // receiving a response and body // 'header' is the request headers. func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { diff --git a/x/net/http/transport.go b/x/net/http/transport.go index e5467e2..99003eb 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -14,10 +14,25 @@ import ( "github.com/goplus/llgo/c/libuv" "github.com/goplus/llgo/c/net" "github.com/goplus/llgo/c/syscall" - xnet "github.com/goplus/llgoexamples/x/net" + xnet "github.com/goplus/llgo/x/net" "github.com/goplus/llgoexamples/rust/hyper" ) +// DefaultTransport is the default implementation of Transport and is +// used by DefaultClient. It establishes network connections as needed +// and caches them for reuse by subsequent calls. It uses HTTP proxies +// as directed by the environment variables HTTP_PROXY, HTTPS_PROXY +// and NO_PROXY (or the lowercase versions thereof). +var DefaultTransport RoundTripper = &Transport{ + //Proxy: ProxyFromEnvironment, + Proxy: nil, +} + +// DefaultMaxIdleConnsPerHost is the default value of Transport's +// MaxIdleConnsPerHost. +const DefaultMaxIdleConnsPerHost = 2 +const defaultHTTPPort = "80" + type Transport struct { altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme reqMu sync.Mutex @@ -53,55 +68,11 @@ type Transport struct { MaxConnsPerHost int } -// DefaultTransport is the default implementation of Transport and is -// used by DefaultClient. It establishes network connections as needed -// and caches them for reuse by subsequent calls. It uses HTTP proxies -// as directed by the environment variables HTTP_PROXY, HTTPS_PROXY -// and NO_PROXY (or the lowercase versions thereof). -var DefaultTransport RoundTripper = &Transport{ - //Proxy: ProxyFromEnvironment, - Proxy: nil, -} - -const ( - defaultHTTPPort = "80" -) - -// persistConn wraps a connection, usually a persistent one -// (but may be used for non-keep-alive requests as well) -type persistConn struct { - // alt optionally specifies the TLS NextProto RoundTripper. - // This is used for HTTP/2 today and future protocols later. - // If it's non-nil, the rest of the fields are unused. - alt RoundTripper - - //br *bufio.Reader // from conn - //bw *bufio.Writer // to conn - //nwrite int64 // bytes written - //writech chan writeRequest // written by roundTrip; read by writeLoop - //closech chan struct{} // closed when conn closed - - t *Transport - cacheKey connectMethodKey - conn *connData - nwrite int64 // bytes written - reqch chan requestAndChan // written by roundTrip; read by readLoop - closech chan struct{} // closed when conn closed - writeLoopDone chan struct{} // closed when write loop ends - - cancelch chan freeChan - timeoutch chan struct{} - - isProxy bool - mu sync.Mutex // guards following fields - numExpectedResponses int - closed error // set non-nil when conn is closed, before closech is closed - canceledErr error // set non-nil if conn is canceled - broken bool // an error has happened on this connection; marked broken so it's not reused. - // mutateHeaderFunc is an optional func to modify extra - // headers on each outbound request before it's written. (the - // original Request given to RoundTrip is not modified) - mutateHeaderFunc func(Header) +// A cancelKey is the key of the reqCanceler map. +// We wrap the *Request in this type since we want to use the original request, +// not any transient one created by roundTrip. +type cancelKey struct { + req *Request } // incomparable is a zero-width, non-comparable type. Adding it to a struct @@ -123,6 +94,15 @@ type requestAndChan struct { callerGone <-chan struct{} // closed when roundTrip caller has returned } +// A writeRequest is sent by the caller's goroutine to the +// writeLoop's goroutine to write a request while the read loop +// concurrently waits on both the write response and the server's +// reply. +type writeRequest struct { + req *transportRequest + ch chan<- error +} + // responseAndError is how the goroutine reading from an HTTP/1 server // communicates with the goroutine doing the RoundTrip. type responseAndError struct { @@ -142,11 +122,56 @@ type freeChan struct { freech chan struct{} } -// A cancelKey is the key of the reqCanceler map. -// We wrap the *Request in this type since we want to use the original request, -// not any transient one created by roundTrip. -type cancelKey struct { - req *Request +type readTrackingBody struct { + io.ReadCloser + didRead bool + didClose bool +} + +func (r *readTrackingBody) Read(data []byte) (int, error) { + r.didRead = true + return r.ReadCloser.Read(data) +} + +func (r *readTrackingBody) Close() error { + r.didClose = true + return r.ReadCloser.Close() +} + +// setupRewindBody returns a new request with a custom body wrapper +// that can report whether the body needs rewinding. +// This lets rewindBody avoid an error result when the request +// does not have GetBody but the body hasn't been readRespLineAndHeader at all yet. +func setupRewindBody(req *Request) *Request { + if req.Body == nil || req.Body == NoBody { + return req + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: req.Body} + return &newReq +} + +// rewindBody returns a new request with the body rewound. +// It returns req unmodified if the body does not need rewinding. +// rewindBody takes care of closing req.Body when appropriate +// (in all cases except when rewindBody returns req unmodified). +func rewindBody(req *Request) (rewound *Request, err error) { + if req.Body == nil || req.Body == NoBody || (!req.Body.(*readTrackingBody).didRead && !req.Body.(*readTrackingBody).didClose) { + return req, nil // nothing to rewind + } + if !req.Body.(*readTrackingBody).didClose { + req.closeBody() + } + if req.GetBody == nil { + return nil, errCannotRewind + } + body, err := req.GetBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: body} + return &newReq, nil } // transportRequest is a wrapper around a *Request that adds @@ -169,16 +194,54 @@ func (tr *transportRequest) extraHeaders() Header { return tr.extra } -// useRegisteredProtocol reports whether an alternate protocol (as registered -// with Transport.RegisterProtocol) should be respected for this request. -func (t *Transport) useRegisteredProtocol(req *Request) bool { - if req.URL.Scheme == "https" && req.requiresHTTP1() { - // If this request requires HTTP/1, don't use the - // "https" alternate protocol, which is used by the - // HTTP/2 code to take over requests if there's an - // existing cached HTTP/2 connection. +func (tr *transportRequest) setError(err error) { + tr.mu.Lock() + if tr.err == nil { + tr.err = err + } + tr.mu.Unlock() +} + +func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + cm.targetScheme = treq.URL.Scheme + cm.targetAddr = canonicalAddr(treq.URL) + if t.Proxy != nil { + cm.proxyURL, err = t.Proxy(treq.Request) + } + cm.treq = treq + cm.onlyH1 = treq.requiresHTTP1() + return cm, err +} + +func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + if t.reqCanceler == nil { + t.reqCanceler = make(map[cancelKey]func(error)) + } + if fn != nil { + t.reqCanceler[key] = fn + } else { + delete(t.reqCanceler, key) + } +} + +// replaceReqCanceler replaces an existing cancel function. If there is no cancel function +// for the request, we don't set the function and return false. +// Since CancelRequest will clear the canceler, we can use the return value to detect if +// the request was canceled since the last setReqCancel call. +func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { + t.reqMu.Lock() + defer t.reqMu.Unlock() + _, ok := t.reqCanceler[key] + if !ok { return false } + if fn != nil { + t.reqCanceler[key] = fn + } else { + delete(t.reqCanceler, key) + } return true } @@ -193,6 +256,21 @@ func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { return altProto[req.URL.Scheme] } +// useRegisteredProtocol reports whether an alternate protocol (as registered +// with Transport.RegisterProtocol) should be respected for this request. +func (t *Transport) useRegisteredProtocol(req *Request) bool { + if req.URL.Scheme == "https" && req.requiresHTTP1() { + // If this request requires HTTP/1, don't use the + // "https" alternate protocol, which is used by the + // HTTP/2 code to take over requests if there's an + // existing cached HTTP/2 connection. + return false + } + return true +} + +// ---------------------------------------------------------- + func (t *Transport) RoundTrip(req *Request) (*Response, error) { //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) //ctx := req.Context() @@ -432,69 +510,69 @@ func (t *Transport) dialConnFor(w *wantConn) { // t.putOrCloseIdleConn(pc) //} - // TODO(spongehah) decConnsPerHost // If an error occurs during the dialing process, the connection count for that host is decreased. // This ensures that the connection count remains accurate even in cases where the dial attempt fails. - //if err != nil { - // t.decConnsPerHost(w.key) - //} + if err != nil { + t.decConnsPerHost(w.key) + } } // decConnsPerHost decrements the per-host connection count for key, // which may in turn give a different waiting goroutine permission to dial. -//func (t *Transport) decConnsPerHost(key connectMethodKey) { -// if t.MaxConnsPerHost <= 0 { -// return -// } -// -// t.connsPerHostMu.Lock() -// defer t.connsPerHostMu.Unlock() -// n := t.connsPerHost[key] -// if n == 0 { -// // Shouldn't happen, but if it does, the counting is buggy and could -// // easily lead to a silent deadlock, so report the problem loudly. -// panic("net/http: internal error: connCount underflow") -// } -// -// // Can we hand this count to a goroutine still waiting to dial? -// // (Some goroutines on the wait list may have timed out or -// // gotten a connection another way. If they're all gone, -// // we don't want to kick off any spurious dial operations.) -// if q := t.connsPerHostWait[key]; q.len() > 0 { -// done := false -// for q.len() > 0 { -// w := q.popFront() -// if w.waiting() { -// go t.dialConnFor(w) -// done = true -// break -// } -// } -// if q.len() == 0 { -// delete(t.connsPerHostWait, key) -// } else { -// // q is a value (like a slice), so we have to store -// // the updated q back into the map. -// t.connsPerHostWait[key] = q -// } -// if done { -// return -// } -// } -// -// // Otherwise, decrement the recorded count. -// if n--; n == 0 { -// delete(t.connsPerHost, key) -// } else { -// t.connsPerHost[key] = n -// } -//} +func (t *Transport) decConnsPerHost(key connectMethodKey) { + if t.MaxConnsPerHost <= 0 { + return + } + + t.connsPerHostMu.Lock() + defer t.connsPerHostMu.Unlock() + n := t.connsPerHost[key] + if n == 0 { + // Shouldn't happen, but if it does, the counting is buggy and could + // easily lead to a silent deadlock, so report the problem loudly. + panic("net/http: internal error: connCount underflow") + } + + // Can we hand this count to a goroutine still waiting to dial? + // (Some goroutines on the wait list may have timed out or + // gotten a connection another way. If they're all gone, + // we don't want to kick off any spurious dial operations.) + if q := t.connsPerHostWait[key]; q.len() > 0 { + done := false + for q.len() > 0 { + w := q.popFront() + if w.waiting() { + go t.dialConnFor(w) + done = true + break + } + } + if q.len() == 0 { + delete(t.connsPerHostWait, key) + } else { + // q is a value (like a slice), so we have to store + // the updated q back into the map. + t.connsPerHostWait[key] = q + } + if done { + return + } + } + + // Otherwise, decrement the recorded count. + if n--; n == 0 { + delete(t.connsPerHost, key) + } else { + t.connsPerHost[key] = n + } +} func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { pconn = &persistConn{ t: t, cacheKey: cm.key(), reqch: make(chan requestAndChan, 1), + writech: make(chan writeRequest, 1), cancelch: make(chan freeChan, 1), timeoutch: make(chan struct{}, 1), closech: make(chan struct{}, 1), @@ -535,6 +613,24 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers return nil, err } pconn.conn = conn + + // hyper specific + // Hookup the IO + hyperIo := newIoWithConnReadWrite(conn) + // We need an executor generally to poll futures + exec := hyper.NewExecutor() + // Prepare client options + opts := hyper.NewClientConnOptions() + opts.Exec(exec) + pconn.io = hyperIo + pconn.exec = exec + pconn.opts = opts + // send the handshake + handshakeTask := hyper.Handshake(hyperIo, opts) + setTaskId(handshakeTask, write) + // Let's wait for the handshake to finish... + exec.Push(handshakeTask) + //if cm.scheme() == "https" { // var firstTLSHost string // if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { @@ -702,10 +798,11 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // In Hyper, the writeLoop() and readLoop() are combined together --> readWriteLoop(). startBytesWritten := pc.nwrite + writeErrCh := make(chan error, 1) + pc.writech <- writeRequest{req, writeErrCh} // Send the request to readWriteLoop(). resc := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{ req: req.Request, cancelKey: req.cancelKey, @@ -731,7 +828,22 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err return nil, fmt.Errorf("request timeout\n") } select { - //case err := <-writeErrCh: + case err := <-writeErrCh: + if debugRoundTrip { + //req.logf("writeErrCh resv: %T/%#v", err, err) + } + if err != nil { + pc.close(fmt.Errorf("write error: %w", err)) + return nil, pc.mapRoundTripError(req, startBytesWritten, err) + } + //if d := pc.t.ResponseHeaderTimeout; d > 0 { + // if debugRoundTrip { + // //req.logf("starting timer for %v", d) + // } + // timer := time.NewTimer(d) + // defer timer.Stop() // prevent leaks + // respHeaderTimer = timer.C + //} case <-pcClosed: pcClosed = nil if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { @@ -768,49 +880,52 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // readWriteLoop handles the main I/O loop for a persistent connection. // It processes incoming requests, sends them to the server, and handles responses. func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { - // Hookup the IO - hyperIo := newIoWithConnReadWrite(pc.conn) - - // We need an executor generally to poll futures - exec := hyper.NewExecutor() - // Prepare client options - opts := hyper.NewClientConnOptions() - opts.Exec(exec) + defer close(pc.writeLoopDone) - handshakeTask := hyper.Handshake(hyperIo, opts) - setTaskId(handshakeTask, sending) - - // Let's wait for the handshake to finish... - exec.Push(handshakeTask) + const debugReadWriteLoop = true // Debug switch provided for developers // The polling state machine! - //for { // Poll all ready tasks and act on them... - rc := <-pc.reqch // blocking alive := true var bodyWriter *io.PipeWriter - var respBody *hyper.Body = nil for alive { select { case fc := <-pc.cancelch: + if debugReadWriteLoop { + println("cancelch") + } // Free the resources - freeResources(nil, respBody, bodyWriter, exec, pc, rc) + //freeResources(nil, respBody, bodyWriter, pc.exec, pc, rc) alive = false + pc.close(errors.New("timeout error")) close(fc.freech) return + case <-pc.closech: + if debugReadWriteLoop { + println("closech") + } + return default: - task := exec.Poll() + task := pc.exec.Poll() if task == nil { loop.Run(libuv.RUN_ONCE) continue } switch (taskId)(uintptr(task.Userdata())) { - case sending: - err := checkTaskType(task, sending) + case write: + if debugReadWriteLoop { + println("write") + } + wc := <-pc.writech // blocking + + startBytesWritten := pc.nwrite + + err := checkTaskType(task, write) if err != nil { - rc.ch <- responseAndError{err: err} + wc.ch <- err // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) + pc.close(err) return } @@ -818,53 +933,110 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { task.Free() // Prepare the hyper.Request - hyperReq, err := newHyperRequest(rc.req) + hyperReq, err := newHyperRequest(wc.req.Request) + if bre, ok := err.(requestBodyReadError); ok { + err = bre.error + // Errors reading from the user's + // Request.Body are high priority. + // Set it here before sending on the + // channels below or calling + // pc.close() which tears down + // connections and causes other + // errors. + wc.req.setError(err) + } if err != nil { - rc.ch <- responseAndError{err: err} + if pc.nwrite == startBytesWritten { + err = nothingWrittenError{err} + } + wc.ch <- err // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) + pc.close(err) return } // Send it! sendTask := client.Send(hyperReq) - setTaskId(sendTask, receiveResp) - sendRes := exec.Push(sendTask) + setTaskId(sendTask, readRespLineAndHeader) + sendRes := pc.exec.Push(sendTask) if sendRes != hyper.OK { - rc.ch <- responseAndError{err: fmt.Errorf("failed to send request")} + wc.ch <- err // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) + pc.close(err) return } // For this example, no longer need the client client.Free() - case receiveResp: - err := checkTaskType(task, receiveResp) - if err != nil { - rc.ch <- responseAndError{err: err} - // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + case readRespLineAndHeader: + if debugReadWriteLoop { + println("readRespLineAndHeader") + } + rc := <-pc.reqch // blocking + + closeErr := errReadLoopExiting // default value, if not changed below + defer func() { + pc.close(closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //pc.t.removeIdleConn(pc) + }() + + //tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { + // if err := pc.t.tryPutIdleConn(pc); err != nil { + // closeErr = err + // if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // trace.PutIdleConn(err) + // } + // return false + // } + // if trace != nil && trace.PutIdleConn != nil { + // trace.PutIdleConn(nil) + // } + // return true + //} + + // Read this once, before loop starts. (to avoid races in tests) + testHookMu.Lock() + testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead + testHookMu.Unlock() + + pc.mu.Lock() + if pc.numExpectedResponses == 0 { + pc.closeLocked(errServerClosedIdle) + pc.mu.Unlock() return } + pc.mu.Unlock() + err := checkTaskType(task, readRespLineAndHeader) // Take the results hyperResp := (*hyper.Response)(task.Value()) task.Free() - resp, err := ReadResponse(hyperResp, rc.req) + var resp *Response + var respBody *hyper.Body + if err == nil { + resp, err = ReadResponse(hyperResp, rc.req) + respBody = hyperResp.Body() + resp.Body, bodyWriter = io.Pipe() + } else { + err = transportReadFromServerError{err} + closeErr = err + } + if err != nil { - rc.ch <- responseAndError{err: err} + select { + case rc.ch <- responseAndError{err: err}: + case <-rc.callerGone: + return + } // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) return } - respBody = hyperResp.Body() - resp.Body, bodyWriter = io.Pipe() - - rc.ch <- responseAndError{res: resp} - // Response has been returned, stop the timer pc.conn.IsCompleted = 1 // Stop the timer @@ -873,70 +1045,92 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { (*libuv.Handle)(c.Pointer(&pc.conn.TimeoutTimer)).Close(nil) } - dataTask := respBody.Data() - setTaskId(dataTask, receiveRespBody) - exec.Push(dataTask) + pc.mu.Lock() + pc.numExpectedResponses-- + pc.mu.Unlock() - // No longer need the response - hyperResp.Free() - case receiveRespBody: - err := checkTaskType(task, receiveRespBody) - if err != nil { - rc.ch <- responseAndError{err: err} - // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) - return + bodyWritable := resp.bodyIsWritable() + hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 + + if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable { + // Don't do keep-alive on error if either party requested a close + // or we get an unexpected informational (1xx) response. + // StatusCode 100 is already handled above. + alive = false } - if task.Type() == hyper.TaskBuf { - buf := (*hyper.Buf)(task.Value()) - bufLen := buf.Len() - bytes := unsafe.Slice((*byte)(buf.Bytes()), bufLen) - if bodyWriter == nil { - rc.ch <- responseAndError{err: fmt.Errorf("ResponseBodyWriter is nil")} - // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) - return + if !hasBody || bodyWritable { + //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) + pc.t.replaceReqCanceler(rc.cancelKey, nil) + + // TODO(spongehah) ConnPool(readWriteLoop) + //// Put the idle conn back into the pool before we send the response + //// so if they process it quickly and make another request, they'll + //// get this same conn. But we use the unbuffered channel 'rc' + //// to guarantee that persistConn.roundTrip got out of its select + //// potentially waiting for this persistConn to close. + //alive = alive && + // !pc.sawEOF && + // pc.wroteRequest() && + // replaced && tryPutIdleConn(trace) + + if bodyWritable { + closeErr = errCallerOwnsConn } - _, err := bodyWriter.Write(bytes) // blocking - if err != nil { - rc.ch <- responseAndError{err: err} - // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + + select { + case rc.ch <- responseAndError{res: resp}: + case <-rc.callerGone: return } - buf.Free() - task.Free() - dataTask := respBody.Data() - setTaskId(dataTask, receiveRespBody) - exec.Push(dataTask) + // Now that they've read from the unbuffered channel, they're safely + // out of the select that also waits on this goroutine to die, so + // we're allowed to exit now if needed (if alive is false) + testHookReadLoopBeforeNextRead() + continue + } + + bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(bodyWriter)) + setTaskId(bodyForeachTask, readRespBody) + pc.exec.Push(bodyForeachTask) - break + rc.ch <- responseAndError{res: resp} + + // No longer need the response + hyperResp.Free() + case readRespBody: + // A background task of reading the response body is completed + if debugReadWriteLoop { + println("readRespBody") + } + err := checkTaskType(task, readRespBody) + if err != nil { + fmt.Println(err) + pc.close(err) + return } - // We are done with the response body if task.Type() != hyper.TaskEmpty { - c.Printf(c.Str("unexpected task type\n")) - rc.ch <- responseAndError{err: fmt.Errorf("unexpected task type\n")} - // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + err = errors.New("unexpected task type\n") + fmt.Println(err) + pc.close(err) return } - // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) - - alive = false + // free the task + task.Free() + bodyWriter.Close() case notSet: // A background task for hyper_client completed... task.Free() } } } - //} } +// ---------------------------------------------------------- + type connData struct { TcpHandle libuv.Tcp ConnectReq libuv.Connect @@ -981,7 +1175,7 @@ func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { } // onRead is the libuv callback for reading from a socket -// This callback function is called when data is available to be read +// This callback function is called when data is available to be readRespLineAndHeader func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { // Get the connection data associated with the stream conn := (*connData)((*libuv.Handle)(c.Pointer(stream)).GetData()) @@ -1002,7 +1196,7 @@ func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { } } -// readCallBack read callback function for Hyper library +// readCallBack readRespLineAndHeader callback function for Hyper library func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { // Get the user data (connection data) conn := (*connData)(userdata) @@ -1096,7 +1290,7 @@ func onTimeout(handle *libuv.Timer) { (*libuv.Handle)(c.Pointer(&ct.conn.TimeoutTimer)).Close(nil) } -// newIoWithConnReadWrite creates a new IO with read and write callbacks +// newIoWithConnReadWrite creates a new IO with readRespLineAndHeader and write callbacks func newIoWithConnReadWrite(connData *connData) *hyper.Io { hyperIo := hyper.NewIo() hyperIo.SetUserdata(c.Pointer(connData)) @@ -1110,9 +1304,9 @@ type taskId c.Int const ( notSet taskId = iota - sending - receiveResp - receiveRespBody + write + readRespLineAndHeader + readRespBody ) // setTaskId Set taskId to the task's userdata as a unique identifier @@ -1124,7 +1318,7 @@ func setTaskId(task *hyper.Task, userData taskId) { // checkTaskType checks the task type func checkTaskType(task *hyper.Task, curTaskId taskId) error { switch curTaskId { - case sending: + case write: if task.Type() == hyper.TaskError { c.Printf(c.Str("handshake task error!\n")) return fail((*hyper.Error)(task.Value())) @@ -1133,7 +1327,7 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { return fmt.Errorf("unexpected task type\n") } return nil - case receiveResp: + case readRespLineAndHeader: if task.Type() == hyper.TaskError { c.Printf(c.Str("send task error!\n")) return fail((*hyper.Error)(task.Value())) @@ -1143,7 +1337,7 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { return fmt.Errorf("unexpected task type\n") } return nil - case receiveRespBody: + case readRespBody: if task.Type() == hyper.TaskError { c.Printf(c.Str("body error!\n")) return fail((*hyper.Error)(task.Value())) @@ -1281,7 +1475,7 @@ func (nwe nothingWrittenError) Unwrap() error { } // transportReadFromServerError is used by Transport.readLoop when the -// 1 byte peek read fails and we're actually anticipating a response. +// 1 byte peek readRespLineAndHeader fails and we're actually anticipating a response. // Usually this is just due to the inherent keep-alive shut down race, // where the server closed the connection at the same time the client // wrote. The underlying err field is usually io.EOF or some @@ -1311,72 +1505,70 @@ var ( testHookReadLoopBeforeNextRead = nop ) -/*// alternateRoundTripper returns the alternate RoundTripper to use -// for this request if the Request's URL scheme requires one, -// or nil for the normal case of using the Transport. -func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { - if !t.useRegisteredProtocol(req) { - return nil - } - altProto, _ := t.altProto.Load().(map[string]RoundTripper) - return altProto[req.URL.Scheme] +var portMap = map[string]string{ + "http": "80", + "https": "443", + "socks5": "1080", } -// useRegisteredProtocol reports whether an alternate protocol (as registered -// with Transport.RegisterProtocol) should be respected for this request. -func (t *Transport) useRegisteredProtocol(req *Request) bool { - if req.URL.Scheme == "https" && req.requiresHTTP1() { - // If this request requires HTTP/1, don't use the - // "https" alternate protocol, which is used by the - // HTTP/2 code to take over requests if there's an - // existing cached HTTP/2 connection. - return false +func idnaASCIIFromURL(url *url.URL) string { + addr := url.Hostname() + if v, err := idnaASCII(addr); err == nil { + addr = v } - return true + return addr } -*/ -func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { - cm.targetScheme = treq.URL.Scheme - cm.targetAddr = canonicalAddr(treq.URL) - if t.Proxy != nil { - cm.proxyURL, err = t.Proxy(treq.Request) +// canonicalAddr returns url.Host but always with a ":port" suffix. +func canonicalAddr(url *url.URL) string { + port := url.Port() + if port == "" { + port = portMap[url.Scheme] } - cm.treq = treq - cm.onlyH1 = treq.requiresHTTP1() - return cm, err + return xnet.JoinHostPort(idnaASCIIFromURL(url), port) } -func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { - t.reqMu.Lock() - defer t.reqMu.Unlock() - if t.reqCanceler == nil { - t.reqCanceler = make(map[cancelKey]func(error)) - } - if fn != nil { - t.reqCanceler[key] = fn - } else { - delete(t.reqCanceler, key) - } -} +// persistConn wraps a connection, usually a persistent one +// (but may be used for non-keep-alive requests as well) +type persistConn struct { + // alt optionally specifies the TLS NextProto RoundTripper. + // This is used for HTTP/2 today and future protocols later. + // If it's non-nil, the rest of the fields are unused. + alt RoundTripper -// replaceReqCanceler replaces an existing cancel function. If there is no cancel function -// for the request, we don't set the function and return false. -// Since CancelRequest will clear the canceler, we can use the return value to detect if -// the request was canceled since the last setReqCancel call. -func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { - t.reqMu.Lock() - defer t.reqMu.Unlock() - _, ok := t.reqCanceler[key] - if !ok { - return false - } - if fn != nil { - t.reqCanceler[key] = fn - } else { - delete(t.reqCanceler, key) - } - return true + //br *bufio.Reader // from conn + //bw *bufio.Writer // to conn + //nwrite int64 // bytes written + //writech chan writeRequest // written by roundTrip; read by writeLoop + //closech chan struct{} // closed when conn closed + + t *Transport + cacheKey connectMethodKey + conn *connData + nwrite int64 // bytes written + reqch chan requestAndChan // written by roundTrip; readRespLineAndHeader by readWriteLoop + writech chan writeRequest // written by roundTrip; readRespLineAndHeader by writeLoop(Already merged into reqch) + closech chan struct{} // closed when conn closed + writeLoopDone chan struct{} // closed when write loop ends + + cancelch chan freeChan + timeoutch chan struct{} + + isProxy bool + mu sync.Mutex // guards following fields + numExpectedResponses int + closed error // set non-nil when conn is closed, before closech is closed + canceledErr error // set non-nil if conn is canceled + broken bool // an error has happened on this connection; marked broken so it's not reused. + // mutateHeaderFunc is an optional func to modify extra + // headers on each outbound request before it's written. (the + // original Request given to RoundTrip is not modified) + mutateHeaderFunc func(Header) + + // hyper specific + exec *hyper.Executor + opts *hyper.ClientConnOptions + io *hyper.Io } func (pc *persistConn) cancelRequest(err error) { @@ -1404,8 +1596,7 @@ func (pc *persistConn) closeLocked(err error) { pc.broken = true if pc.closed == nil { pc.closed = err - // TODO(spongehah) decConnsPerHost - //pc.t.decConnsPerHost(pc.cacheKey) + pc.t.decConnsPerHost(pc.cacheKey) // Close HTTP/1 (pc.alt == nil) connection. // HTTP/2 closes its connection itself. if pc.alt == nil { @@ -1492,81 +1683,6 @@ func (pc *persistConn) isBroken() bool { return b } -type readTrackingBody struct { - io.ReadCloser - didRead bool - didClose bool -} - -func (r *readTrackingBody) Read(data []byte) (int, error) { - r.didRead = true - return r.ReadCloser.Read(data) -} - -func (r *readTrackingBody) Close() error { - r.didClose = true - return r.ReadCloser.Close() -} - -// setupRewindBody returns a new request with a custom body wrapper -// that can report whether the body needs rewinding. -// This lets rewindBody avoid an error result when the request -// does not have GetBody but the body hasn't been read at all yet. -func setupRewindBody(req *Request) *Request { - if req.Body == nil || req.Body == NoBody { - return req - } - newReq := *req - newReq.Body = &readTrackingBody{ReadCloser: req.Body} - return &newReq -} - -// rewindBody returns a new request with the body rewound. -// It returns req unmodified if the body does not need rewinding. -// rewindBody takes care of closing req.Body when appropriate -// (in all cases except when rewindBody returns req unmodified). -func rewindBody(req *Request) (rewound *Request, err error) { - if req.Body == nil || req.Body == NoBody || (!req.Body.(*readTrackingBody).didRead && !req.Body.(*readTrackingBody).didClose) { - return req, nil // nothing to rewind - } - if !req.Body.(*readTrackingBody).didClose { - req.closeBody() - } - if req.GetBody == nil { - return nil, errCannotRewind - } - body, err := req.GetBody() - if err != nil { - return nil, err - } - newReq := *req - newReq.Body = &readTrackingBody{ReadCloser: body} - return &newReq, nil -} - -var portMap = map[string]string{ - "http": "80", - "https": "443", - "socks5": "1080", -} - -func idnaASCIIFromURL(url *url.URL) string { - addr := url.Hostname() - if v, err := idnaASCII(addr); err == nil { - addr = v - } - return addr -} - -// canonicalAddr returns url.Host but always with a ":port" suffix. -func canonicalAddr(url *url.URL) string { - port := url.Port() - if port == "" { - port = portMap[url.Scheme] - } - return xnet.JoinHostPort(idnaASCIIFromURL(url), port) -} - // connectMethod is the map key (in its String form) for keeping persistent // TCP connections alive for subsequent HTTP requests. // @@ -1595,6 +1711,14 @@ type connectMethod struct { onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 } +// connectMethodKey is the map key version of connectMethod, with a +// stringified proxy URL (or the empty string) instead of a pointer to +// a URL. +type connectMethodKey struct { + proxy, scheme, addr string + onlyH1 bool +} + func (cm *connectMethod) key() connectMethodKey { proxyStr := "" targetAddr := cm.targetAddr @@ -1642,14 +1766,6 @@ func (cm *connectMethod) proxyAuth() string { return "" } -// connectMethodKey is the map key version of connectMethod, with a -// stringified proxy URL (or the empty string) instead of a pointer to -// a URL. -type connectMethodKey struct { - proxy, scheme, addr string - onlyH1 bool -} - // A wantConn records state about a wanted connection // (that is, an active call to getConn). // The conn may be gotten by dialing or by finding an idle connection, From c2eb82d09421d8d58ef154c18e9ae1357fcd832b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Thu, 22 Aug 2024 18:49:43 +0800 Subject: [PATCH 15/21] WIP(x/http/client): bodyEOFSignal packaging & optimized readLoop() & adjusted timeout logic --- x/net/http/client.go | 6 +- x/net/http/request.go | 4 +- x/net/http/response.go | 1 + x/net/http/transport.go | 384 ++++++++++++++++++++++++++++++---------- 4 files changed, 295 insertions(+), 100 deletions(-) diff --git a/x/net/http/client.go b/x/net/http/client.go index 4fc6e41..6bf72c9 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -307,7 +307,9 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d } // TODO(spongehah) timeout + req.timeoutch = make(chan struct{}, 1) //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) + sub := deadline.Sub(time.Now()) req.timeout = sub resp, err = rt.RoundTrip(req) @@ -504,7 +506,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) } - cancel := make(chan struct{}) + cancel := make(chan struct{}, 1) req.Cancel = cancel doCancel := func() { @@ -518,7 +520,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi } } - stopTimerCh := make(chan struct{}) + stopTimerCh := make(chan struct{}, 1) var once sync.Once stopTimer = func() { once.Do(func() { diff --git a/x/net/http/request.go b/x/net/http/request.go index cb50936..8a1fb88 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -38,7 +38,9 @@ type Request struct { RemoteAddr string RequestURI string //TLS *tls.ConnectionState - Cancel <-chan struct{} + Cancel <-chan struct{} + timeoutch chan struct{} //optional + Response *Response timeout time.Duration ctx context.Context diff --git a/x/net/http/response.go b/x/net/http/response.go index d2a7dd5..c647f20 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -21,6 +21,7 @@ type Response struct { ContentLength int64 TransferEncoding []string Close bool + Uncompressed bool //Trailer Header Request *Request } diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 99003eb..52b7134 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -1,6 +1,7 @@ package http import ( + "compress/gzip" "context" "errors" "fmt" @@ -117,11 +118,6 @@ type connAndTimeoutChan struct { timeoutch chan struct{} } -type freeChan struct { - _ incomparable - freech chan struct{} -} - type readTrackingBody struct { io.ReadCloser didRead bool @@ -141,7 +137,7 @@ func (r *readTrackingBody) Close() error { // setupRewindBody returns a new request with a custom body wrapper // that can report whether the body needs rewinding. // This lets rewindBody avoid an error result when the request -// does not have GetBody but the body hasn't been readRespLineAndHeader at all yet. +// does not have GetBody but the body hasn't been read at all yet. func setupRewindBody(req *Request) *Request { if req.Body == nil || req.Body == NoBody { return req @@ -269,6 +265,32 @@ func (t *Transport) useRegisteredProtocol(req *Request) bool { return true } +// CancelRequest cancels an in-flight request by closing its connection. +// CancelRequest should only be called after RoundTrip has returned. +// +// Deprecated: Use Request.WithContext to create a request with a +// cancelable context instead. CancelRequest cannot cancel HTTP/2 +// requests. +func (t *Transport) CancelRequest(req *Request) { + t.cancelRequest(cancelKey{req}, errRequestCanceled) +} + +// Cancel an in-flight request, recording the error value. +// Returns whether the request was canceled. +func (t *Transport) cancelRequest(key cancelKey, err error) bool { + // This function must not return until the cancel func has completed. + // See: https://golang.org/issue/34658 + t.reqMu.Lock() + defer t.reqMu.Unlock() + cancel := t.reqCanceler[key] + delete(t.reqCanceler, key) + if cancel != nil { + cancel(err) + } + + return cancel != nil +} + // ---------------------------------------------------------- func (t *Transport) RoundTrip(req *Request) (*Response, error) { @@ -451,6 +473,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // TODO(spongehah) cancel(t.getConn) //case <-req.Cancel: // return nil, errRequestCanceledConn + case <-treq.Request.timeoutch: + return nil, fmt.Errorf("request timeout\n") //case <-req.Context().Done(): // return nil, req.Context().Err() case err := <-cancelc: @@ -573,12 +597,8 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers cacheKey: cm.key(), reqch: make(chan requestAndChan, 1), writech: make(chan writeRequest, 1), - cancelch: make(chan freeChan, 1), - timeoutch: make(chan struct{}, 1), closech: make(chan struct{}, 1), writeLoopDone: make(chan struct{}, 1), - //writech: make(chan writeRequest, 1), - //closech: make(chan struct{}), } //if cm.scheme() == "https" && t.hasCustomTLSDialer() { @@ -675,9 +695,8 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // } //} - if conn.IsCompleted != 1 { - go pconn.readWriteLoop(libuv.DefaultLoop()) - } + go pconn.readWriteLoop(libuv.DefaultLoop()) + return pconn, nil } @@ -699,7 +718,7 @@ func (t *Transport) dial(ctx context.Context, pconn *persistConn, cm connectMeth libuv.InitTimer(loop, &conn.TimeoutTimer) ct := &connAndTimeoutChan{ conn: conn, - timeoutch: pconn.timeoutch, + timeoutch: treq.Request.timeoutch, } (*libuv.Handle)(c.Pointer(&conn.TimeoutTimer)).SetData(c.Pointer(ct)) conn.TimeoutTimer.Start(onTimeout, uint64(treq.timeout.Milliseconds()), 0) @@ -716,14 +735,14 @@ func (t *Transport) dial(ctx context.Context, pconn *persistConn, cm connectMeth var res *net.AddrInfo status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) if status != 0 { - close(pconn.timeoutch) + close(treq.Request.timeoutch) return nil, fmt.Errorf("getaddrinfo error\n") } (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, onConnect) if status != 0 { - close(pconn.timeoutch) + close(treq.Request.timeoutch) return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } @@ -781,7 +800,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err req.extraHeaders().Set("Connection", "close") } - gone := make(chan struct{}) + gone := make(chan struct{}, 1) defer close(gone) defer func() { @@ -799,7 +818,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // In Hyper, the writeLoop() and readLoop() are combined together --> readWriteLoop(). startBytesWritten := pc.nwrite writeErrCh := make(chan error, 1) - pc.writech <- writeRequest{req, writeErrCh} + pc.writech <- writeRequest{req: req, ch: writeErrCh} // Send the request to readWriteLoop(). resc := make(chan responseAndError, 1) @@ -820,13 +839,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err for { testHookWaitResLoop() - // Determine whether timeout has occurred - if pc.conn.IsCompleted == 1 { - rc := <-pc.reqch // blocking - // Free the resources - freeResources(nil, nil, nil, nil, pc, rc) - return nil, fmt.Errorf("request timeout\n") - } select { case err := <-writeErrCh: if debugRoundTrip { @@ -855,23 +867,21 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err //case <-respHeaderTimer: case re := <-resc: if (re.res == nil) == (re.err == nil) { + println(1) return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) } if debugRoundTrip { + println(2) //req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err) } if re.err != nil { + println(3) return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) } return re.res, nil // TODO(spongehah) cancel(pc.roundTrip) //case <-cancelChan: - case <-pc.timeoutch: - freech := make(chan struct{}, 1) - pc.cancelch <- freeChan{ - freech: freech, - } - <-freech + case <-req.Request.timeoutch: return nil, fmt.Errorf("request timeout\n") } } @@ -884,22 +894,17 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { const debugReadWriteLoop = true // Debug switch provided for developers + if debugReadWriteLoop { + println("readWriteLoop start") + } + // The polling state machine! // Poll all ready tasks and act on them... alive := true var bodyWriter *io.PipeWriter + var rw readWaiter for alive { select { - case fc := <-pc.cancelch: - if debugReadWriteLoop { - println("cancelch") - } - // Free the resources - //freeResources(nil, respBody, bodyWriter, pc.exec, pc, rc) - alive = false - pc.close(errors.New("timeout error")) - close(fc.freech) - return case <-pc.closech: if debugReadWriteLoop { println("closech") @@ -911,18 +916,22 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { loop.Run(libuv.RUN_ONCE) continue } - switch (taskId)(uintptr(task.Userdata())) { + taskId := (taskId)(uintptr(task.Userdata())) + if debugReadWriteLoop { + println(taskId) + } + switch taskId { case write: if debugReadWriteLoop { println("write") } - wc := <-pc.writech // blocking + wr := <-pc.writech // blocking startBytesWritten := pc.nwrite err := checkTaskType(task, write) if err != nil { - wc.ch <- err + wr.ch <- err // Free the resources //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) pc.close(err) @@ -933,7 +942,16 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { task.Free() // Prepare the hyper.Request - hyperReq, err := newHyperRequest(wc.req.Request) + hyperReq, err := newHyperRequest(wr.req.Request) + if err == nil { + // Send it! + sendTask := client.Send(hyperReq) + setTaskId(sendTask, read) + sendRes := pc.exec.Push(sendTask) + if sendRes != hyper.OK { + err = errors.New("failed to send the request") + } + } if bre, ok := err.(requestBodyReadError); ok { err = bre.error // Errors reading from the user's @@ -943,25 +961,14 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // pc.close() which tears down // connections and causes other // errors. - wc.req.setError(err) + wr.req.setError(err) } if err != nil { if pc.nwrite == startBytesWritten { err = nothingWrittenError{err} } - wc.ch <- err - // Free the resources - //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) - pc.close(err) - return - } - - // Send it! - sendTask := client.Send(hyperReq) - setTaskId(sendTask, readRespLineAndHeader) - sendRes := pc.exec.Push(sendTask) - if sendRes != hyper.OK { - wc.ch <- err + //pc.writeErrCh <- err // to the body reader, which might recycle us + wr.ch <- err // to the roundTrip function // Free the resources //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) pc.close(err) @@ -970,10 +977,14 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // For this example, no longer need the client client.Free() - case readRespLineAndHeader: if debugReadWriteLoop { - println("readRespLineAndHeader") + println("write end") + } + case read: + if debugReadWriteLoop { + println("read") } + rc := <-pc.reqch // blocking closeErr := errReadLoopExiting // default value, if not changed below @@ -997,6 +1008,12 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // return true //} + // eofc is used to block caller goroutines reading from Response.Body + // at EOF until this goroutines has (potentially) added the connection + // back to the idle pool. + eofc := make(chan struct{}, 1) + defer close(eofc) // unblock reader on errors + // Read this once, before loop starts. (to avoid races in tests) testHookMu.Lock() testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead @@ -1010,7 +1027,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } pc.mu.Unlock() - err := checkTaskType(task, readRespLineAndHeader) + err := checkTaskType(task, read) // Take the results hyperResp := (*hyper.Response)(task.Value()) task.Free() @@ -1038,8 +1055,6 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } // Response has been returned, stop the timer - pc.conn.IsCompleted = 1 - // Stop the timer if rc.req.timeout > 0 { pc.conn.TimeoutTimer.Stop() (*libuv.Handle)(c.Pointer(&pc.conn.TimeoutTimer)).Close(nil) @@ -1091,24 +1106,71 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { continue } + waitForBodyRead := make(chan bool, 2) + body := &bodyEOFSignal{ + body: resp.Body, + earlyCloseFn: func() error { + waitForBodyRead <- false + <-eofc // will be closed by deferred call at the end of the function + return nil + }, + fn: func(err error) error { + isEOF := err == io.EOF + waitForBodyRead <- isEOF + if isEOF { + <-eofc // see comment above eofc declaration + } else if err != nil { + if cerr := pc.canceled(); cerr != nil { + return cerr + } + } + return err + }, + } + resp.Body = body + + // TODO(spongehah) gzip fail + if rc.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { + resp.Body = &gzipReader{body: body} + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") + resp.ContentLength = -1 + resp.Uncompressed = true + } + + rw.waitForBodyRead = waitForBodyRead + rw.rc = rc + rw.eofc = eofc bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(bodyWriter)) - setTaskId(bodyForeachTask, readRespBody) + setTaskId(bodyForeachTask, readDone) pc.exec.Push(bodyForeachTask) + // TODO(spongehah) select blocking + //select { + //case rc.ch <- responseAndError{res: resp}: + //case <-rc.callerGone: + // return + //} rc.ch <- responseAndError{res: resp} // No longer need the response hyperResp.Free() - case readRespBody: + if debugReadWriteLoop { + println("read end") + } + + //pc.t.replaceReqCanceler(rc.cancelKey, nil) + //eofc <- struct{}{} + case readDone: // A background task of reading the response body is completed if debugReadWriteLoop { - println("readRespBody") + println("readDone") } - err := checkTaskType(task, readRespBody) + err := checkTaskType(task, readDone) if err != nil { fmt.Println(err) pc.close(err) - return + alive = false } if task.Type() != hyper.TaskEmpty { @@ -1121,6 +1183,39 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // free the task task.Free() bodyWriter.Close() + + // Before looping back to the top of this function and peeking on + // the bufio.Reader, wait for the caller goroutine to finish + // reading the response body. (or for cancellation or death) + rc := rw.rc + select { + //case bodyEOF := <-rw.waitForBodyRead: + case <-rw.waitForBodyRead: + //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool + pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool + // TODO(spongehah) ConnPool(readWriteLoop) + //alive = alive && + // bodyEOF && + // !pc.sawEOF && + // pc.wroteRequest() && + // replaced && tryPutIdleConn(trace) + + rw.eofc <- struct{}{} + // TODO(spongehah) cancel(pc.readWriteLoop) + //case <-rc.req.Cancel: + // alive = false + // pc.t.CancelRequest(rc.req) + //case <-rc.req.Context().Done(): + // alive = false + // pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err()) + case <-pc.closech: + alive = false + } + + testHookReadLoopBeforeNextRead() + if debugReadWriteLoop { + println("readDone end") + } case notSet: // A background task for hyper_client completed... task.Free() @@ -1136,7 +1231,6 @@ type connData struct { ConnectReq libuv.Connect ReadBuf libuv.Buf TimeoutTimer libuv.Timer - IsCompleted int ReadBufFilled uintptr ReadWaker *hyper.Waker WriteWaker *hyper.Waker @@ -1175,12 +1269,10 @@ func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { } // onRead is the libuv callback for reading from a socket -// This callback function is called when data is available to be readRespLineAndHeader +// This callback function is called when data is available to be read func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { // Get the connection data associated with the stream conn := (*connData)((*libuv.Handle)(c.Pointer(stream)).GetData()) - //conn := (*ConnData)(stream.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(stream)).data // If data was read (nread > 0) if nread > 0 { @@ -1196,7 +1288,7 @@ func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { } } -// readCallBack readRespLineAndHeader callback function for Hyper library +// readCallBack read callback function for Hyper library func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { // Get the user data (connection data) conn := (*connData)(userdata) @@ -1236,8 +1328,6 @@ func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uin func onWrite(req *libuv.Write, status c.Int) { // Get the connection data associated with the write request conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) - //conn := (*ConnData)(req.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(req)).data // If there's a pending write waker if conn.WriteWaker != nil { @@ -1254,11 +1344,9 @@ func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui conn := (*connData)(userdata) // Create a libuv buffer initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) - //req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) req := &libuv.Write{} // Associate the connection data with the write request (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) - //req.Data = c.Pointer(conn) // Perform the asynchronous write operation ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, onWrite) @@ -1282,15 +1370,12 @@ func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui // onTimeout is the libuv callback for a timeout func onTimeout(handle *libuv.Timer) { ct := (*connAndTimeoutChan)((*libuv.Handle)(c.Pointer(handle)).GetData()) - if ct.conn.IsCompleted != 1 { - ct.conn.IsCompleted = 1 - ct.timeoutch <- struct{}{} - } + close(ct.timeoutch) // Close the timer (*libuv.Handle)(c.Pointer(&ct.conn.TimeoutTimer)).Close(nil) } -// newIoWithConnReadWrite creates a new IO with readRespLineAndHeader and write callbacks +// newIoWithConnReadWrite creates a new IO with read and write callbacks func newIoWithConnReadWrite(connData *connData) *hyper.Io { hyperIo := hyper.NewIo() hyperIo.SetUserdata(c.Pointer(connData)) @@ -1305,10 +1390,16 @@ type taskId c.Int const ( notSet taskId = iota write - readRespLineAndHeader - readRespBody + read + readDone ) +type readWaiter struct { + rc requestAndChan + waitForBodyRead chan bool + eofc chan struct{} +} + // setTaskId Set taskId to the task's userdata as a unique identifier func setTaskId(task *hyper.Task, userData taskId) { var data = userData @@ -1327,9 +1418,9 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { return fmt.Errorf("unexpected task type\n") } return nil - case readRespLineAndHeader: + case read: if task.Type() == hyper.TaskError { - c.Printf(c.Str("send task error!\n")) + c.Printf(c.Str("write task error!\n")) return fail((*hyper.Error)(task.Value())) } if task.Type() != hyper.TaskResponse { @@ -1337,9 +1428,9 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { return fmt.Errorf("unexpected task type\n") } return nil - case readRespBody: + case readDone: if task.Type() == hyper.TaskError { - c.Printf(c.Str("body error!\n")) + c.Printf(c.Str("read error!\n")) return fail((*hyper.Error)(task.Value())) } return nil @@ -1391,8 +1482,6 @@ func closeChannels(rc requestAndChan, pc *persistConn) { // Closing the channel close(rc.ch) close(pc.reqch) - close(pc.timeoutch) - close(pc.cancelch) } // freeConnData frees the connection data @@ -1475,7 +1564,7 @@ func (nwe nothingWrittenError) Unwrap() error { } // transportReadFromServerError is used by Transport.readLoop when the -// 1 byte peek readRespLineAndHeader fails and we're actually anticipating a response. +// 1 byte peek read fails and we're actually anticipating a response. // Usually this is just due to the inherent keep-alive shut down race, // where the server closed the connection at the same time the client // wrote. The underlying err field is usually io.EOF or some @@ -1546,14 +1635,11 @@ type persistConn struct { cacheKey connectMethodKey conn *connData nwrite int64 // bytes written - reqch chan requestAndChan // written by roundTrip; readRespLineAndHeader by readWriteLoop - writech chan writeRequest // written by roundTrip; readRespLineAndHeader by writeLoop(Already merged into reqch) + reqch chan requestAndChan // written by roundTrip; read by readWriteLoop + writech chan writeRequest // written by roundTrip; read by writeLoop(Already merged into reqch) closech chan struct{} // closed when conn closed writeLoopDone chan struct{} // closed when write loop ends - cancelch chan freeChan - timeoutch chan struct{} - isProxy bool mu sync.Mutex // guards following fields numExpectedResponses int @@ -1900,3 +1986,107 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) { cleaned = true } } + +// bodyEOFSignal is used by the HTTP/1 transport when reading response +// bodies to make sure we see the end of a response body before +// proceeding and reading on the connection again. +// +// It wraps a ReadCloser but runs fn (if non-nil) at most +// once, right before its final (error-producing) Read or Close call +// returns. fn should return the new error to return from Read or Close. +// +// If earlyCloseFn is non-nil and Close is called before io.EOF is +// seen, earlyCloseFn is called instead of fn, and its return value is +// the return value from Close. +type bodyEOFSignal struct { + body io.ReadCloser + mu sync.Mutex // guards following 4 fields + closed bool // whether Close has been called + rerr error // sticky Read error + fn func(error) error // err will be nil on Read io.EOF + earlyCloseFn func() error // optional alt Close func used if io.EOF not seen +} + +var errReadOnClosedResBody = errors.New("http: read on closed response body") + +func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { + es.mu.Lock() + closed, rerr := es.closed, es.rerr + es.mu.Unlock() + if closed { + return 0, errReadOnClosedResBody + } + if rerr != nil { + return 0, rerr + } + + n, err = es.body.Read(p) + if err != nil { + es.mu.Lock() + defer es.mu.Unlock() + if es.rerr == nil { + es.rerr = err + } + err = es.condfn(err) + } + return +} + +func (es *bodyEOFSignal) Close() error { + es.mu.Lock() + defer es.mu.Unlock() + if es.closed { + return nil + } + es.closed = true + if es.earlyCloseFn != nil && es.rerr != io.EOF { + return es.earlyCloseFn() + } + err := es.body.Close() + return es.condfn(err) +} + +// caller must hold es.mu. +func (es *bodyEOFSignal) condfn(err error) error { + if es.fn == nil { + return err + } + err = es.fn(err) + es.fn = nil + return err +} + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type gzipReader struct { + _ incomparable + body *bodyEOFSignal // underlying HTTP/1 response body framing + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // any error from gzip.NewReader; sticky +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zr == nil { + if gz.zerr == nil { + gz.zr, gz.zerr = gzip.NewReader(gz.body) + } + if gz.zerr != nil { + return 0, gz.zerr + } + } + + gz.body.mu.Lock() + if gz.body.closed { + err = errReadOnClosedResBody + } + gz.body.mu.Unlock() + + if err != nil { + return 0, err + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} From a68bc29fada2c1b6789a890769641191fa01622c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Fri, 23 Aug 2024 18:28:37 +0800 Subject: [PATCH 16/21] WIP(x/http/client): Use body to wrap readCloser & optimize req.write code --- x/net/http/client.go | 8 +- x/net/http/http.go | 11 ++ x/net/http/request.go | 386 ++++++++++++++++++++++++++-------------- x/net/http/response.go | 6 +- x/net/http/server.go | 12 ++ x/net/http/transfer.go | 334 ++++++++++++++++++++++++++-------- x/net/http/transport.go | 326 ++++++++++++++------------------- x/net/http/util.go | 98 ++++++++++ x/net/ipsock.go | 79 +++++++- x/net/net.go | 20 +++ x/net/parse.go | 12 ++ 11 files changed, 882 insertions(+), 410 deletions(-) create mode 100644 x/net/http/server.go create mode 100644 x/net/net.go create mode 100644 x/net/parse.go diff --git a/x/net/http/client.go b/x/net/http/client.go index 6bf72c9..bf1bfd4 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -139,9 +139,7 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { resp.closeBody() return nil, uerr(fmt.Errorf("failed to parse Location header %q: %v", loc, err)) } - // TODO(spongehah) redirect: Why use host := "" - //host := "" - host := u.Host + host := "" if req.Host != "" && req.Host != req.URL.Host { // If the caller specified a custom Host header and the @@ -239,7 +237,7 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { // didTimeout is non-nil only if err != nil. func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { - // TODO(spongehah) cookie + // TODO(spongehah) cookie(c.send) if c.Jar != nil { for _, cookie := range c.Jar.Cookies(req.URL) { req.AddCookie(cookie) @@ -306,7 +304,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d forkReq() } - // TODO(spongehah) timeout + // TODO(spongehah) timeout(send) req.timeoutch = make(chan struct{}, 1) //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) diff --git a/x/net/http/http.go b/x/net/http/http.go index f668906..2f5d5ab 100644 --- a/x/net/http/http.go +++ b/x/net/http/http.go @@ -13,6 +13,17 @@ func isNotToken(r rune) bool { return !IsTokenRune(r) } +// stringContainsCTLByte reports whether s contains any ASCII control character. +func stringContainsCTLByte(s string) bool { + for i := 0; i < len(s); i++ { + b := s[i] + if b < ' ' || b == 0x7f { + return true + } + } + return false +} + // removeEmptyPort strips the empty port in ":port" to "" // as mandated by RFC 3986 Section 6.2.3. func removeEmptyPort(host string) string { diff --git a/x/net/http/request.go b/x/net/http/request.go index 8a1fb88..be81f0e 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -46,7 +46,7 @@ type Request struct { ctx context.Context } -var defaultChunkSize uintptr = 8192 +const defaultChunkSize = 8192 // NewRequest wraps NewRequestWithContext using context.Background. func NewRequest(method, url string, body io.Reader) (*Request, error) { @@ -128,6 +128,7 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R r := bytes.NewReader(buf) return io.NopCloser(r), nil } + case *bytes.Reader: req.ContentLength = int64(v.Len()) snapshot := *v @@ -166,122 +167,36 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R return req, nil } -func printInformational(userdata c.Pointer, resp *hyper.Response) { - status := resp.Status() - fmt.Println("Informational (1xx): ", status) -} - -type postReq struct { - req *Request - buf []byte -} - -func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { - req := (*postReq)(userdata) - n, err := req.req.Body.Read(req.buf) - if err != nil { - if err == io.EOF { - *chunk = nil - return hyper.PollReady - } - fmt.Println("error reading upload file: ", err) - return hyper.PollError - } - if n > 0 { - *chunk = hyper.CopyBuf(&req.buf[0], uintptr(n)) - return hyper.PollReady - } - if n == 0 { - *chunk = nil - return hyper.PollReady - } - - fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) - return hyper.PollError -} - -func setPostDataNoCopy(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { - type buf struct { - data *uint8 - len uintptr - Unused [16]byte - } - req := (*postReq)(userdata) - buffer := &buf{ - data: &req.buf[0], - len: uintptr(len(req.buf)), - } - - *chunk = (*hyper.Buf)(c.Pointer(buffer)) - n, err := req.req.Body.Read(req.buf) - if err != nil { - if err == io.EOF { - *chunk = nil - return hyper.PollReady - } - fmt.Println("error reading upload file: ", err) - return hyper.PollError - } - if n > 0 { - return hyper.PollReady - } - if n == 0 { - *chunk = nil - return hyper.PollReady - } - - fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) - return hyper.PollError -} - -func newHyperRequest(req *Request) (*hyper.Request, error) { - host := req.Host - uri := req.URL.RequestURI() - method := req.Method - // Prepare the request - hyperReq := hyper.NewRequest() - // Set the request method and uri - if hyperReq.SetMethod(&[]byte(method)[0], c.Strlen(c.AllocaCStr(method))) != hyper.OK { - return nil, fmt.Errorf("error setting method %s\n", method) - } - if hyperReq.SetURI(&[]byte(uri)[0], c.Strlen(c.AllocaCStr(uri))) != hyper.OK { - return nil, fmt.Errorf("error setting uri %s\n", uri) - } - // Set the request headers - reqHeaders := hyperReq.Headers() - if reqHeaders.Set(&[]byte("Host")[0], c.Strlen(c.Str("Host")), &[]byte(host)[0], c.Strlen(c.AllocaCStr(host))) != hyper.OK { - return nil, fmt.Errorf("error setting header: Host: %s\n", host) - } - - if method == "POST" && req.Body != nil { - // 100-continue - if req.ProtoAtLeast(1, 1) && req.Body != nil && req.expectsContinue() { - hyperReq.OnInformational(printInformational, nil) - } - - hyperReqBody := hyper.NewBody() - reqData := &postReq{ - req: req, - buf: make([]byte, 3), - } - hyperReqBody.SetUserdata(c.Pointer(reqData)) - hyperReqBody.SetDataFunc(setPostData) - hyperReq.SetBody(hyperReqBody) - } - - // Add user-defined request headers to hyper.Request - err := req.setHeaders(hyperReq) - if err != nil { - return nil, err - } - - return hyperReq, nil -} +//func setPostDataNoCopy(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { +// req := (*postReq)(userdata) +// buf := req.hyperBuf.Bytes() +// len := req.hyperBuf.Len() +// n, err := req.req.Body.Read(unsafe.Slice(buf, len)) +// if err != nil { +// if err == io.EOF { +// *chunk = nil +// return hyper.PollReady +// } +// fmt.Println("error reading upload file: ", err) +// return hyper.PollError +// } +// if n > 0 { +// *chunk = req.hyperBuf +// return hyper.PollReady +// } +// if n == 0 { +// *chunk = nil +// return hyper.PollReady +// } +// +// fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) +// return hyper.PollError +//} // setHeaders sets the headers of the request -func (req *Request) setHeaders(hyperReq *hyper.Request) error { +func (r *Request) setHeaders(hyperReq *hyper.Request) error { headers := hyperReq.Headers() - for key, values := range req.Header { + for key, values := range r.Header { valueLen := len(values) if valueLen > 1 { for _, value := range values { @@ -318,23 +233,6 @@ func (r *Request) closeBody() error { return r.Body.Close() } -func validMethod(method string) bool { - /* - Method = "OPTIONS" ; Section 9.2 - | "GET" ; Section 9.3 - | "HEAD" ; Section 9.4 - | "POST" ; Section 9.5 - | "PUT" ; Section 9.6 - | "DELETE" ; Section 9.7 - | "TRACE" ; Section 9.8 - | "CONNECT" ; Section 9.9 - | extension-method - extension-method = token - token = 1* - */ - return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 -} - // Context returns the request's context. To change the context, use // Clone or WithContext. // @@ -387,6 +285,215 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } +// errMissingHost is returned by Write when there is no Host or URL present in +// the Request. +var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set") + +// extraHeaders may be nil +// waitForContinue may be nil +// always closes body +func (r *Request) write(usingProxy bool, extraHeader Header, client *hyper.ClientConn, exec *hyper.Executor) (err error) { + //trace := httptrace.ContextClientTrace(r.Context()) + //if trace != nil && trace.WroteRequest != nil { + // defer func() { + // trace.WroteRequest(httptrace.WroteRequestInfo{ + // Err: err, + // }) + // }() + //} + + //closed := false + //defer func() { + // if closed { + // return + // } + // if closeErr := r.closeBody(); closeErr != nil && err == nil { + // err = closeErr + // } + //}() + + // Prepare the hyper.Request + hyperReq, err := r.newHyperRequest(usingProxy, extraHeader) + if err != nil { + return err + } + // Send it! + sendTask := client.Send(hyperReq) + setTaskId(sendTask, read) + sendRes := exec.Push(sendTask) + if sendRes != hyper.OK { + err = errors.New("failed to send the request") + } + return err +} + +func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.Request, error) { + // Find the target host. Prefer the Host: header, but if that + // is not given, use the host from the request URL. + // + // Clean the host, in case it arrives with unexpected stuff in it. + host := r.Host + if host == "" { + if r.URL == nil { + return nil, errMissingHost + } + host = r.URL.Host + } + host, err := PunycodeHostPort(host) + if err != nil { + return nil, err + } + // Validate that the Host header is a valid header in general, + // but don't validate the host itself. This is sufficient to avoid + // header or request smuggling via the Host field. + // The server can (and will, if it's a net/http server) reject + // the request if it doesn't consider the host valid. + if !ValidHostHeader(host) { + // Historically, we would truncate the Host header after '/' or ' '. + // Some users have relied on this truncation to convert a network + // address such as Unix domain socket path into a valid, ignored + // Host header (see https://go.dev/issue/61431). + // + // We don't preserve the truncation, because sending an altered + // header field opens a smuggling vector. Instead, zero out the + // Host header entirely if it isn't valid. (An empty Host is valid; + // see RFC 9112 Section 3.2.) + // + // Return an error if we're sending to a proxy, since the proxy + // probably can't do anything useful with an empty Host header. + if !usingProxy { + host = "" + } else { + return nil, errors.New("http: invalid Host header") + } + } + + // According to RFC 6874, an HTTP client, proxy, or other + // intermediary must remove any IPv6 zone identifier attached + // to an outgoing URI. + host = removeZone(host) + + ruri := r.URL.RequestURI() + if usingProxy && r.URL.Scheme != "" && r.URL.Opaque == "" { + ruri = r.URL.Scheme + "://" + host + ruri + } else if r.Method == "CONNECT" && r.URL.Path == "" { + // CONNECT requests normally give just the host and port, not a full URL. + ruri = host + if r.URL.Opaque != "" { + ruri = r.URL.Opaque + } + } + if stringContainsCTLByte(ruri) { + return nil, errors.New("net/http: can't write control character in Request.URL") + } + + + + + // Prepare the hyper request + hyperReq := hyper.NewRequest() + + // Set the request line, default HTTP/1.1 + if hyperReq.SetMethod(&[]byte(r.Method)[0], c.Strlen(c.AllocaCStr(r.Method))) != hyper.OK { + return nil, fmt.Errorf("error setting method %s\n", r.Method) + } + if hyperReq.SetURI(&[]byte(ruri)[0], c.Strlen(c.AllocaCStr(ruri))) != hyper.OK { + return nil, fmt.Errorf("error setting uri %s\n", ruri) + } + if hyperReq.SetVersion(c.Int(hyper.HTTPVersion11)) != hyper.OK { + return nil, fmt.Errorf("error setting httpversion %s\n", "HTTP/1.1") + } + + // Set the request headers + reqHeaders := hyperReq.Headers() + if reqHeaders.Set(&[]byte("Host")[0], c.Strlen(c.Str("Host")), &[]byte(host)[0], c.Strlen(c.AllocaCStr(host))) != hyper.OK { + return nil, fmt.Errorf("error setting header: Host: %s\n", host) + } + err = r.setHeaders(hyperReq) + if err != nil { + return nil, err + } + + if r.Body != nil { + // 100-continue + if r.ProtoAtLeast(1, 1) && r.Body != nil && r.expectsContinue() { + hyperReq.OnInformational(printInformational, nil) + } + + hyperReqBody := hyper.NewBody() + //buf := make([]byte, 2) + //hyperBuf := hyper.CopyBuf(&buf[0], uintptr(2)) + reqData := &postReq{ + req: r, + buf: make([]byte, defaultChunkSize), + //hyperBuf: hyperBuf, + } + hyperReqBody.SetUserdata(c.Pointer(reqData)) + hyperReqBody.SetDataFunc(setPostData) + //hyperReqBody.SetDataFunc(setPostDataNoCopy) + hyperReq.SetBody(hyperReqBody) + } + + return hyperReq, nil +} + +func printInformational(userdata c.Pointer, resp *hyper.Response) { + status := resp.Status() + fmt.Println("Informational (1xx): ", status) +} + +type postReq struct { + req *Request + buf []byte + //hyperBuf *hyper.Buf +} + +func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { + req := (*postReq)(userdata) + n, err := req.req.Body.Read(req.buf) + if err != nil { + if err == io.EOF { + println("EOF") + *chunk = nil + req.req.Body.Close() + return hyper.PollReady + } + fmt.Println("error reading request body: ", err) + return hyper.PollError + } + if n > 0 { + *chunk = hyper.CopyBuf(&req.buf[0], uintptr(n)) + return hyper.PollReady + } + if n == 0 { + println("n == 0") + *chunk = nil + req.req.Body.Close() + return hyper.PollReady + } + + req.req.Body.Close() + fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + return hyper.PollError +} + +func validMethod(method string) bool { + /* + Method = "OPTIONS" ; Section 9.2 + | "GET" ; Section 9.3 + | "HEAD" ; Section 9.4 + | "POST" ; Section 9.5 + | "PUT" ; Section 9.6 + | "DELETE" ; Section 9.7 + | "TRACE" ; Section 9.8 + | "CONNECT" ; Section 9.9 + | extension-method + extension-method = token + token = 1* + */ + return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 +} + // readCookies parses all "Cookie" values from the header h and // returns the successfully parsed Cookies. // @@ -446,3 +553,20 @@ func idnaASCII(v string) (string, error) { } return idna.Lookup.ToASCII(v) } + +// removeZone removes IPv6 zone identifier from host. +// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080" +func removeZone(host string) string { + if !strings.HasPrefix(host, "[") { + return host + } + i := strings.LastIndex(host, "]") + if i < 0 { + return host + } + j := strings.LastIndex(host[:i], "%") + if j < 0 { + return host + } + return host[:j] + host[i:] +} diff --git a/x/net/http/response.go b/x/net/http/response.go index c647f20..8151ac2 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -32,7 +32,7 @@ func (r *Response) closeBody() { } } -func ReadResponse(hyperResp *hyper.Response, req *Request) (*Response, error) { +func ReadResponse(r *io.PipeReader, req *Request, hyperResp *hyper.Response) (*Response, error) { resp := &Response{ Request: req, Header: make(Header), @@ -42,7 +42,7 @@ func ReadResponse(hyperResp *hyper.Response, req *Request) (*Response, error) { fixPragmaCacheControl(req.Header) - err := readTransfer(resp) + err := readTransfer(resp, r) if err != nil { return nil, err } @@ -54,9 +54,9 @@ func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { rp := hyperResp.ReasonPhrase() rpLen := hyperResp.ReasonPhraseLen() + // Parse the first line of the response. resp.Status = strconv.Itoa(int(hyperResp.Status())) + " " + c.GoString((*int8)(c.Pointer(rp)), rpLen) resp.StatusCode = int(hyperResp.Status()) - version := int(hyperResp.Version()) resp.ProtoMajor, resp.ProtoMinor = splitTwoDigitNumber(version) resp.Proto = fmt.Sprintf("HTTP/%d.%d", resp.ProtoMajor, resp.ProtoMinor) diff --git a/x/net/http/server.go b/x/net/http/server.go new file mode 100644 index 0000000..f38cbd0 --- /dev/null +++ b/x/net/http/server.go @@ -0,0 +1,12 @@ +package http + +// maxPostHandlerReadBytes is the max number of Request.Body bytes not +// consumed by a handler that the server will read from the client +// in order to keep a connection alive. If there are more bytes than +// this then the server to be paranoid instead sends a "Connection: +// close" response. +// +// This number is approximately what a typical machine's TCP buffer +// size is anyway. (if we have the bytes on the machine, we might as +// well read them) +const maxPostHandlerReadBytes = 256 << 10 diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index 823ef7d..cf96f84 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -1,11 +1,13 @@ package http import ( + "errors" "fmt" "io" "net/textproto" "strconv" "strings" + "sync" "unicode/utf8" ) @@ -24,13 +26,49 @@ type transferReader struct { Trailer Header } -// unsupportedTEError reports unsupported transfer-encodings. -type unsupportedTEError struct { - err string +// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. +func (t *transferReader) parseTransferEncoding() error { + raw, present := t.Header["Transfer-Encoding"] + if !present { + return nil + } + delete(t.Header, "Transfer-Encoding") + + // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. + if !t.protoAtLeast(1, 1) { + return nil + } + + // Like nginx, we only support a single Transfer-Encoding header field, and + // only if set to "chunked". This is one of the most security sensitive + // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it + // strict and simple. + if len(raw) != 1 { + return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} + } + if !EqualFold(raw[0], "chunked") { + return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} + } + + // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field + // in any message that contains a Transfer-Encoding header field." + // + // but also: "If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. Such a message might indicate an attempt to perform + // request smuggling (Section 9.5) or response splitting (Section 9.4) and + // ought to be handled as an error. A sender MUST remove the received + // Content-Length field prior to forwarding such a message downstream." + // + // Reportedly, these appear in the wild. + delete(t.Header, "Content-Length") + + t.Chunked = true + return nil } -func (uste *unsupportedTEError) Error() string { - return uste.err +func (t *transferReader) protoAtLeast(m, n int) bool { + return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) } // NoBody is an io.ReadCloser with no bytes. Read always returns EOF @@ -45,7 +83,17 @@ func (noBody) Read([]byte) (int, error) { return 0, io.EOF } func (noBody) Close() error { return nil } func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil } -func readTransfer(msg any) (err error) { +// unsupportedTEError reports unsupported transfer-encodings. +type unsupportedTEError struct { + err string +} + +func (uste *unsupportedTEError) Error() string { + return uste.err +} + +// msg is *Request or *Response. +func readTransfer(msg any, r *io.PipeReader) (err error) { t := &transferReader{RequestMethod: "GET"} // Unify input @@ -74,6 +122,11 @@ func readTransfer(msg any) (err error) { panic("unexpected type") } + // Default to HTTP/1.1 + if t.ProtoMajor == 0 && t.ProtoMinor == 0 { + t.ProtoMajor, t.ProtoMinor = 1, 1 + } + // Transfer-Encoding: chunked, and overriding Content-Length. if err = t.parseTransferEncoding(); err != nil { return err @@ -93,6 +146,7 @@ func readTransfer(msg any) (err error) { t.ContentLength = realLength } + // TODO(spongehah) Trailer(readTransfer) // Trailer //t.Trailer, err = fixTrailer(t.Header, t.Chunked) @@ -109,40 +163,42 @@ func readTransfer(msg any) (err error) { // Prepare body reader. ContentLength < 0 means chunked encoding // or close connection when finished, since multipart is not supported yet - //switch { - //case t.Chunked: - // if isResponse && noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { - // t.Body = NoBody - // } else { - // t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} - // } - //case realLength == 0: - // t.Body = NoBody - //case realLength > 0: - // t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} - //default: - // // realLength < 0, i.e. "Content-Length" not mentioned in header - // if t.Close { - // // Close semantics (i.e. HTTP/1.0) - // t.Body = &body{src: r, closing: t.Close} - // } else { - // // Persistent connection (i.e. HTTP/1.1) - // t.Body = NoBody - // } - //} + switch { + case t.Chunked: + if isResponse && noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { + t.Body = NoBody + } else { + // TODO(spongehah) ChunkReader(readTransfer) + //t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} + t.Body = &body{src: r, hdr: msg, r: r, closing: t.Close} + } + case realLength == 0: + t.Body = NoBody + case realLength > 0: + t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} + default: + // realLength < 0, i.e. "Content-Length" not mentioned in header + if t.Close { + // Close semantics (i.e. HTTP/1.0) + t.Body = &body{src: r, closing: t.Close} + } else { + // Persistent connection (i.e. HTTP/1.1) + t.Body = NoBody + } + } // Unify output switch rr := msg.(type) { case *Request: - //rr.Body = t.Body - //rr.ContentLength = t.ContentLength - //if t.Chunked { - // rr.TransferEncoding = []string{"chunked"} - //} + rr.Body = t.Body + rr.ContentLength = t.ContentLength + if t.Chunked { + rr.TransferEncoding = []string{"chunked"} + } rr.Close = t.Close //rr.Trailer = t.Trailer case *Response: - //rr.Body = t.Body + rr.Body = t.Body rr.ContentLength = t.ContentLength if t.Chunked { rr.TransferEncoding = []string{"chunked"} @@ -154,51 +210,6 @@ func readTransfer(msg any) (err error) { return nil } -// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. -func (t *transferReader) parseTransferEncoding() error { - raw, present := t.Header["Transfer-Encoding"] - if !present { - return nil - } - delete(t.Header, "Transfer-Encoding") - - // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. - if !t.protoAtLeast(1, 1) { - return nil - } - - // Like nginx, we only support a single Transfer-Encoding header field, and - // only if set to "chunked". This is one of the most security sensitive - // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it - // strict and simple. - if len(raw) != 1 { - return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} - } - if !EqualFold(raw[0], "chunked") { - return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} - } - - // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field - // in any message that contains a Transfer-Encoding header field." - // - // but also: "If a message is received with both a Transfer-Encoding and a - // Content-Length header field, the Transfer-Encoding overrides the - // Content-Length. Such a message might indicate an attempt to perform - // request smuggling (Section 9.5) or response splitting (Section 9.4) and - // ought to be handled as an error. A sender MUST remove the received - // Content-Length field prior to forwarding such a message downstream." - // - // Reportedly, these appear in the wild. - delete(t.Header, "Content-Length") - - t.Chunked = true - return nil -} - -func (t *transferReader) protoAtLeast(m, n int) bool { - return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) -} - // Determine the expected body length, using RFC 7230 Section 3.3. This // function is not a method, because ultimately it should be shared by // ReadResponse and ReadRequest. @@ -329,6 +340,173 @@ func fixTrailer(header Header, chunked bool) (Header, error) { return trailer, nil } +// body turns a Reader into a ReadCloser. +// Close ensures that the body has been fully read +// and then reads the trailer if necessary. +type body struct { + src io.Reader + hdr any // non-nil (Response or Request) value means read trailer + //r *bufio.Reader // underlying wire-format reader for the trailer + r io.Reader // underlying wire-format reader for the trailer + closing bool // is the connection to be closed after reading body? + doEarlyClose bool // whether Close should stop early + + mu sync.Mutex // guards following, and calls to Read and Close + sawEOF bool + closed bool + earlyClose bool // Close called and we didn't read to the end of src + onHitEOF func() // if non-nil, func to call when EOF is Read +} + +// ErrBodyReadAfterClose is returned when reading a Request or Response +// Body after the body has been closed. This typically happens when the body is +// read after an HTTP Handler calls WriteHeader or Write on its +// ResponseWriter. +var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body") + +func (b *body) Read(p []byte) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return 0, ErrBodyReadAfterClose + } + return b.readLocked(p) +} + +// Must hold b.mu. +func (b *body) readLocked(p []byte) (n int, err error) { + if b.sawEOF { + return 0, io.EOF + } + n, err = b.src.Read(p) + + if err == io.EOF { + b.sawEOF = true + // Chunked case. Read the trailer. + if b.hdr != nil { + // TODO(spongehah) Trailer(b.readLocked) + //if e := b.readTrailer(); e != nil { + // err = e + // // Something went wrong in the trailer, we must not allow any + // // further reads of any kind to succeed from body, nor any + // // subsequent requests on the server connection. See + // // golang.org/issue/12027 + // b.sawEOF = false + // b.closed = true + //} + b.hdr = nil + } else { + // If the server declared the Content-Length, our body is a LimitedReader + // and we need to check whether this EOF arrived early. + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 { + err = io.ErrUnexpectedEOF + } + } + } + + // If we can return an EOF here along with the read data, do + // so. This is optional per the io.Reader contract, but doing + // so helps the HTTP transport code recycle its connection + // earlier (since it will see this EOF itself), even if the + // client doesn't do future reads or Close. + if err == nil && n > 0 { + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N == 0 { + err = io.EOF + b.sawEOF = true + } + } + + if b.sawEOF && b.onHitEOF != nil { + b.onHitEOF() + } + + return n, err +} + +// unreadDataSizeLocked returns the number of bytes of unread input. +// It returns -1 if unknown. +// b.mu must be held. +func (b *body) unreadDataSizeLocked() int64 { + if lr, ok := b.src.(*io.LimitedReader); ok { + return lr.N + } + return -1 +} + +func (b *body) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return nil + } + var err error + switch { + case b.sawEOF: + // Already saw EOF, so no need going to look for it. + case b.hdr == nil && b.closing: + // no trailer and closing the connection next. + // no point in reading to EOF. + case b.doEarlyClose: + // Read up to maxPostHandlerReadBytes bytes of the body, looking + // for EOF (and trailers), so we can re-use this connection. + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > maxPostHandlerReadBytes { + // There was a declared Content-Length, and we have more bytes remaining + // than our maxPostHandlerReadBytes tolerance. So, give up. + b.earlyClose = true + } else { + var n int64 + // Consume the body, or, which will also lead to us reading + // the trailer headers after the body, if present. + n, err = io.CopyN(io.Discard, bodyLocked{b}, maxPostHandlerReadBytes) + if err == io.EOF { + err = nil + } + if n == maxPostHandlerReadBytes { + b.earlyClose = true + } + } + default: + // Fully consume the body, which will also lead to us reading + // the trailer headers after the body, if present. + _, err = io.Copy(io.Discard, bodyLocked{b}) + } + b.closed = true + return err +} + +// bodyLocked is an io.Reader reading from a *body when its mutex is +// already held. +type bodyLocked struct { + b *body +} + +func (bl bodyLocked) Read(p []byte) (n int, err error) { + if bl.b.closed { + return 0, ErrBodyReadAfterClose + } + return bl.b.readLocked(p) +} + +func (b *body) didEarlyClose() bool { + b.mu.Lock() + defer b.mu.Unlock() + return b.earlyClose +} + +// bodyRemains reports whether future Read calls might +// yield data. +func (b *body) bodyRemains() bool { + b.mu.Lock() + defer b.mu.Unlock() + return !b.sawEOF +} + +func (b *body) registerOnHitEOF(fn func()) { + b.mu.Lock() + defer b.mu.Unlock() + b.onHitEOF = fn +} + // foreachHeaderElement splits v according to the "#rule" construction // in RFC 7230 section 7 and calls fn for each non-empty element. func foreachHeaderElement(v string, fn func(string)) { diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 52b7134..fe3efc3 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "log" "net/url" "sync" "sync/atomic" @@ -13,9 +14,9 @@ import ( "github.com/goplus/llgo/c" "github.com/goplus/llgo/c/libuv" - "github.com/goplus/llgo/c/net" + cnet "github.com/goplus/llgo/c/net" "github.com/goplus/llgo/c/syscall" - xnet "github.com/goplus/llgo/x/net" + "github.com/goplus/llgoexamples/x/net" "github.com/goplus/llgoexamples/rust/hyper" ) @@ -204,6 +205,7 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM if t.Proxy != nil { cm.proxyURL, err = t.Proxy(treq.Request) } + // TODO(spongehah) cm.treq(connectMethod) cm.treq = treq cm.onlyH1 = treq.requiresHTTP1() return cm, err @@ -352,7 +354,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } for { - // TODO(spongehah) timeout: because of that ctx not initialized ( initialized in setRequestCancel() ) + // TODO(spongehah) timeout(t.RoundTrip): because of that ctx not initialized ( initialized in setRequestCancel() ) //select { //case <-ctx.Done(): // req.closeBody() @@ -394,7 +396,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } // Failed. Clean up and determine whether to retry. - // TODO(spongehah) Retry & ConnPool + // TODO(spongehah) Retry & ConnPool(t.RoundTrip) return nil, err } } @@ -421,7 +423,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi } }() - // TODO(spongehah) ConnPool + // TODO(spongehah) ConnPool(t.getConn) //// Queue for idle connection. //if delivered := t.queueForIdleConn(w); delivered { // pc := w.pc @@ -524,7 +526,7 @@ func (t *Transport) dialConnFor(w *wantConn) { pc, err := t.dialConn(w.ctx, w.cm) w.tryDeliver(pc, err) - // TODO(spongehah) ConnPool + // TODO(spongehah) ConnPool(t.dialConnFor) //delivered := w.tryDeliver(pc, err) // Handle undelivered or shareable connections //if err == nil && (!delivered || pc.alt != nil) { @@ -601,6 +603,15 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers writeLoopDone: make(chan struct{}, 1), } + //trace := httptrace.ContextClientTrace(ctx) + //wrapErr := func(err error) error { + // if cm.proxyURL != nil { + // // Return a typed error, per Issue 16997 + // return &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err} + // } + // return err + //} + //if cm.scheme() == "https" && t.hasCustomTLSDialer() { // var err error // pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr()) @@ -628,7 +639,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // } //} else { //conn, err := t.dial(ctx, "tcp", cm.addr()) - conn, err := t.dial(ctx, pconn, cm) + conn, err := t.dial(ctx, cm) if err != nil { return nil, err } @@ -642,9 +653,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // Prepare client options opts := hyper.NewClientConnOptions() opts.Exec(exec) - pconn.io = hyperIo pconn.exec = exec - pconn.opts = opts // send the handshake handshakeTask := hyper.Handshake(hyperIo, opts) setTaskId(handshakeTask, write) @@ -662,7 +671,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers //} //} - // TODO(spongehah) Proxy(https/sock5) + // TODO(spongehah) Proxy(https/sock5)(t.dialConn) // Proxy setup. switch { case cm.proxyURL == nil: @@ -700,7 +709,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers return pconn, nil } -func (t *Transport) dial(ctx context.Context, pconn *persistConn, cm connectMethod) (*connData, error) { +func (t *Transport) dial(ctx context.Context, cm connectMethod) (*connData, error) { treq := cm.treq host := treq.URL.Hostname() port := treq.URL.Port() @@ -724,16 +733,17 @@ func (t *Transport) dial(ctx context.Context, pconn *persistConn, cm connectMeth conn.TimeoutTimer.Start(onTimeout, uint64(treq.timeout.Milliseconds()), 0) } + libuv.InitTcp(loop, &conn.TcpHandle) libuv.InitTcp(loop, &conn.TcpHandle) (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) - var hints net.AddrInfo + var hints cnet.AddrInfo c.Memset(c.Pointer(&hints), 0, unsafe.Sizeof(hints)) hints.Family = syscall.AF_UNSPEC hints.SockType = syscall.SOCK_STREAM - var res *net.AddrInfo - status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) + var res *cnet.AddrInfo + status := cnet.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) if status != 0 { close(treq.Request.timeoutch) return nil, fmt.Errorf("getaddrinfo error\n") @@ -746,14 +756,14 @@ func (t *Transport) dial(ctx context.Context, pconn *persistConn, cm connectMeth return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } - net.Freeaddrinfo(res) + cnet.Freeaddrinfo(res) return conn, nil } func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { - // TODO(spongehah) ConnPool + // TODO(spongehah) ConnPool(pc.roundTrip) //pc.t.putOrCloseIdleConn(pc) return nil, errRequestCanceled } @@ -816,7 +826,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // request body. // In Hyper, the writeLoop() and readLoop() are combined together --> readWriteLoop(). - startBytesWritten := pc.nwrite + startBytesWritten := pc.conn.nwrite writeErrCh := make(chan error, 1) pc.writech <- writeRequest{req: req, ch: writeErrCh} @@ -890,8 +900,42 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // readWriteLoop handles the main I/O loop for a persistent connection. // It processes incoming requests, sends them to the server, and handles responses. func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { + // writeLoop related defer close(pc.writeLoopDone) + // readLoop related + closeErr := errReadLoopExiting // default value, if not changed below + defer func() { + pc.close(closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //pc.t.removeIdleConn(pc) + }() + + //tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { + // if err := pc.t.tryPutIdleConn(pc); err != nil { + // closeErr = err + // if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // trace.PutIdleConn(err) + // } + // return false + // } + // if trace != nil && trace.PutIdleConn != nil { + // trace.PutIdleConn(nil) + // } + // return true + //} + + // eofc is used to block caller goroutines reading from Response.Body + // at EOF until this goroutines has (potentially) added the connection + // back to the idle pool. + eofc := make(chan struct{}, 1) + defer close(eofc) // unblock reader on errors + + // Read this once, before loop starts. (to avoid races in tests) + testHookMu.Lock() + testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead + testHookMu.Unlock() + const debugReadWriteLoop = true // Debug switch provided for developers if debugReadWriteLoop { @@ -918,7 +962,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } taskId := (taskId)(uintptr(task.Userdata())) if debugReadWriteLoop { - println(taskId) + println("taskId: ", taskId) } switch taskId { case write: @@ -927,31 +971,16 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } wr := <-pc.writech // blocking - startBytesWritten := pc.nwrite - + startBytesWritten := pc.conn.nwrite err := checkTaskType(task, write) - if err != nil { - wr.ch <- err - // Free the resources - //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) - pc.close(err) - return - } - client := (*hyper.ClientConn)(task.Value()) task.Free() - - // Prepare the hyper.Request - hyperReq, err := newHyperRequest(wr.req.Request) if err == nil { - // Send it! - sendTask := client.Send(hyperReq) - setTaskId(sendTask, read) - sendRes := pc.exec.Push(sendTask) - if sendRes != hyper.OK { - err = errors.New("failed to send the request") - } + // TODO(spongehah) Proxy(writeLoop) + err = wr.req.Request.write(pc.isProxy, wr.req.extra, client, pc.exec) } + // For this request, no longer need the client + client.Free() if bre, ok := err.(requestBodyReadError); ok { err = bre.error // Errors reading from the user's @@ -964,19 +993,15 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { wr.req.setError(err) } if err != nil { - if pc.nwrite == startBytesWritten { + if pc.conn.nwrite == startBytesWritten { err = nothingWrittenError{err} } //pc.writeErrCh <- err // to the body reader, which might recycle us wr.ch <- err // to the roundTrip function - // Free the resources - //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) pc.close(err) return } - // For this example, no longer need the client - client.Free() if debugReadWriteLoop { println("write end") } @@ -985,39 +1010,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { println("read") } - rc := <-pc.reqch // blocking - - closeErr := errReadLoopExiting // default value, if not changed below - defer func() { - pc.close(closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //pc.t.removeIdleConn(pc) - }() - - //tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { - // if err := pc.t.tryPutIdleConn(pc); err != nil { - // closeErr = err - // if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { - // trace.PutIdleConn(err) - // } - // return false - // } - // if trace != nil && trace.PutIdleConn != nil { - // trace.PutIdleConn(nil) - // } - // return true - //} - - // eofc is used to block caller goroutines reading from Response.Body - // at EOF until this goroutines has (potentially) added the connection - // back to the idle pool. - eofc := make(chan struct{}, 1) - defer close(eofc) // unblock reader on errors - - // Read this once, before loop starts. (to avoid races in tests) - testHookMu.Lock() - testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead - testHookMu.Unlock() + err := checkTaskType(task, read) pc.mu.Lock() if pc.numExpectedResponses == 0 { @@ -1027,7 +1020,9 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } pc.mu.Unlock() - err := checkTaskType(task, read) + rc := <-pc.reqch // blocking + //trace := httptrace.ContextClientTrace(rc.req.Context()) + // Take the results hyperResp := (*hyper.Response)(task.Value()) task.Free() @@ -1035,22 +1030,24 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { var resp *Response var respBody *hyper.Body if err == nil { - resp, err = ReadResponse(hyperResp, rc.req) + var pr *io.PipeReader + pr, bodyWriter = io.Pipe() + resp, err = ReadResponse(pr, rc.req, hyperResp) respBody = hyperResp.Body() - resp.Body, bodyWriter = io.Pipe() } else { err = transportReadFromServerError{err} closeErr = err } + // No longer need the response + hyperResp.Free() + if err != nil { select { case rc.ch <- responseAndError{err: err}: case <-rc.callerGone: return } - // Free the resources - freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) return } @@ -1129,7 +1126,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } resp.Body = body - // TODO(spongehah) gzip fail + // TODO(spongehah) gzip fail(readWriteLoop) if rc.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { resp.Body = &gzipReader{body: body} resp.Header.Del("Content-Encoding") @@ -1140,12 +1137,11 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { rw.waitForBodyRead = waitForBodyRead rw.rc = rc - rw.eofc = eofc bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(bodyWriter)) setTaskId(bodyForeachTask, readDone) pc.exec.Push(bodyForeachTask) - // TODO(spongehah) select blocking + // TODO(spongehah) select blocking(readWriteLoop) //select { //case rc.ch <- responseAndError{res: resp}: //case <-rc.callerGone: @@ -1153,46 +1149,31 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { //} rc.ch <- responseAndError{res: resp} - // No longer need the response - hyperResp.Free() if debugReadWriteLoop { println("read end") } - - //pc.t.replaceReqCanceler(rc.cancelKey, nil) - //eofc <- struct{}{} case readDone: // A background task of reading the response body is completed if debugReadWriteLoop { println("readDone") } - err := checkTaskType(task, readDone) - if err != nil { - fmt.Println(err) - pc.close(err) - alive = false - } - - if task.Type() != hyper.TaskEmpty { - err = errors.New("unexpected task type\n") - fmt.Println(err) - pc.close(err) - return + if bodyWriter != nil { + bodyWriter.Close() } + checkTaskType(task, readDone) + hyperBodyEOF := task.Type() == hyper.TaskEmpty // free the task task.Free() - bodyWriter.Close() // Before looping back to the top of this function and peeking on // the bufio.Reader, wait for the caller goroutine to finish // reading the response body. (or for cancellation or death) - rc := rw.rc select { - //case bodyEOF := <-rw.waitForBodyRead: - case <-rw.waitForBodyRead: + case bodyEOF := <-rw.waitForBodyRead: + bodyEOF = bodyEOF && hyperBodyEOF //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool - pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool + pc.t.replaceReqCanceler(rw.rc.cancelKey, nil) // before pc might return to idle pool // TODO(spongehah) ConnPool(readWriteLoop) //alive = alive && // bodyEOF && @@ -1200,14 +1181,14 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // pc.wroteRequest() && // replaced && tryPutIdleConn(trace) - rw.eofc <- struct{}{} + eofc <- struct{}{} // TODO(spongehah) cancel(pc.readWriteLoop) - //case <-rc.req.Cancel: + //case <-rw.rc.req.Cancel: // alive = false - // pc.t.CancelRequest(rc.req) - //case <-rc.req.Context().Done(): + // pc.t.CancelRequest(rw.rc.req) + //case <-rw.rc.req.Context().Done(): // alive = false - // pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err()) + // pc.t.cancelRequest(rw.rc.cancelKey, rw.rc.req.Context().Err()) case <-pc.closech: alive = false } @@ -1232,12 +1213,25 @@ type connData struct { ReadBuf libuv.Buf TimeoutTimer libuv.Timer ReadBufFilled uintptr + nwrite int64 // bytes written(Replaced from persistConn's nwrite) ReadWaker *hyper.Waker WriteWaker *hyper.Waker } func (conn *connData) Close() error { - freeConnData(conn) + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + conn.ReadWaker = nil + } + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + conn.WriteWaker = nil + } + if conn.ReadBuf.Base != nil { + c.Free(c.Pointer(conn.ReadBuf.Base)) + conn.ReadBuf.Base = nil + } + (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).Close(nil) return nil } @@ -1352,6 +1346,7 @@ func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, onWrite) // If the write operation was successfully initiated if ret >= 0 { + conn.nwrite += int64(bufLen) // Return the number of bytes to be written return bufLen } @@ -1397,7 +1392,6 @@ const ( type readWaiter struct { rc requestAndChan waitForBodyRead chan bool - eofc chan struct{} } // setTaskId Set taskId to the task's userdata as a unique identifier @@ -1411,95 +1405,51 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { switch curTaskId { case write: if task.Type() == hyper.TaskError { - c.Printf(c.Str("handshake task error!\n")) + log.Printf("[readWriteLoop::write]handshake task error!\n") return fail((*hyper.Error)(task.Value())) } if task.Type() != hyper.TaskClientConn { - return fmt.Errorf("unexpected task type\n") + return fmt.Errorf("[readWriteLoop::write]unexpected task type\n") } return nil case read: if task.Type() == hyper.TaskError { - c.Printf(c.Str("write task error!\n")) + log.Printf("[readWriteLoop::read]write task error!\n") return fail((*hyper.Error)(task.Value())) } if task.Type() != hyper.TaskResponse { - c.Printf(c.Str("unexpected task type\n")) - return fmt.Errorf("unexpected task type\n") + c.Printf(c.Str("[readWriteLoop::read]unexpected task type\n")) + return errors.New("[readWriteLoop::read]unexpected task type\n") } return nil case readDone: if task.Type() == hyper.TaskError { - c.Printf(c.Str("read error!\n")) + log.Printf("[readWriteLoop::readDone]read response body error!\n") return fail((*hyper.Error)(task.Value())) } return nil case notSet: } - return fmt.Errorf("unexpected TaskId\n") + return errors.New("[readWriteLoop]unexpected task type\n") } // fail prints the error details and panics func fail(err *hyper.Error) error { if err != nil { - c.Printf(c.Str("error code: %d\n"), err.Code()) + c.Printf(c.Str("[readWriteLoop]error code: %d\n"), err.Code()) // grab the error details var errBuf [256]c.Char errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) - c.Printf(c.Str("details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) + c.Printf(c.Str("[readWriteLoop]details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) // clean up the error err.Free() - return fmt.Errorf("hyper request error, error code: %d\n", int(err.Code())) + return fmt.Errorf("[readWriteLoop]hyper request error, error code: %d\n", int(err.Code())) } return nil } -// freeResources frees the resources -func freeResources(task *hyper.Task, respBody *hyper.Body, bodyWriter *io.PipeWriter, exec *hyper.Executor, pc *persistConn, rc requestAndChan) { - // Cleaning up before exiting - if task != nil { - task.Free() - } - if respBody != nil { - respBody.Free() - } - if bodyWriter != nil { - bodyWriter.Close() - } - if exec != nil { - exec.Free() - } - (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) - freeConnData(pc.conn) - - closeChannels(rc, pc) -} - -// closeChannels closes the channels -func closeChannels(rc requestAndChan, pc *persistConn) { - // Closing the channel - close(rc.ch) - close(pc.reqch) -} - -// freeConnData frees the connection data -func freeConnData(conn *connData) { - if conn.ReadWaker != nil { - conn.ReadWaker.Free() - conn.ReadWaker = nil - } - if conn.WriteWaker != nil { - conn.WriteWaker.Free() - conn.WriteWaker = nil - } - if conn.ReadBuf.Base != nil { - c.Free(c.Pointer(conn.ReadBuf.Base)) - conn.ReadBuf.Base = nil - } -} - // ---------------------------------------------------------- // error values for debugging and testing, not seen by users. @@ -1614,7 +1564,7 @@ func canonicalAddr(url *url.URL) string { if port == "" { port = portMap[url.Scheme] } - return xnet.JoinHostPort(idnaASCIIFromURL(url), port) + return net.JoinHostPort(idnaASCIIFromURL(url), port) } // persistConn wraps a connection, usually a persistent one @@ -1625,22 +1575,18 @@ type persistConn struct { // If it's non-nil, the rest of the fields are unused. alt RoundTripper - //br *bufio.Reader // from conn - //bw *bufio.Writer // to conn - //nwrite int64 // bytes written - //writech chan writeRequest // written by roundTrip; read by writeLoop - //closech chan struct{} // closed when conn closed - - t *Transport - cacheKey connectMethodKey - conn *connData - nwrite int64 // bytes written - reqch chan requestAndChan // written by roundTrip; read by readWriteLoop - writech chan writeRequest // written by roundTrip; read by writeLoop(Already merged into reqch) - closech chan struct{} // closed when conn closed - writeLoopDone chan struct{} // closed when write loop ends - - isProxy bool + t *Transport + cacheKey connectMethodKey + conn *connData + //tlsState *tls.ConnectionState + //nwrite int64 // bytes written(Replaced by connData.nwrite) + reqch chan requestAndChan // written by roundTrip; read by readWriteLoop + writech chan writeRequest // written by roundTrip; read by readWriteLoop + closech chan struct{} // closed when conn closed + isProxy bool + + writeLoopDone chan struct{} // closed when readWriteLoop ends + mu sync.Mutex // guards following fields numExpectedResponses int closed error // set non-nil when conn is closed, before closech is closed @@ -1653,8 +1599,6 @@ type persistConn struct { // hyper specific exec *hyper.Executor - opts *hyper.ClientConnOptions - io *hyper.Io } func (pc *persistConn) cancelRequest(err error) { @@ -1693,6 +1637,10 @@ func (pc *persistConn) closeLocked(err error) { } } pc.mutateHeaderFunc = nil + // hyper related + if pc.exec != nil { + pc.exec.Free() + } } // mapRoundTripError returns the appropriate error value for @@ -1738,14 +1686,14 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte } if _, ok := err.(transportReadFromServerError); ok { - if pc.nwrite == startBytesWritten { + if pc.conn.nwrite == startBytesWritten { return nothingWrittenError{err} } // Don't decorate return err } if pc.isBroken() { - if pc.nwrite == startBytesWritten { + if pc.conn.nwrite == startBytesWritten { return nothingWrittenError{err} } return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %w", err) @@ -1887,7 +1835,7 @@ func (w *wantConn) cancel(t *Transport, err error) { w.err = err w.mu.Unlock() - // TODO(spongehah) ConnPool + // TODO(spongehah) ConnPool(w.cancel) //if pc != nil { // t.putOrCloseIdleConn(pc) //} diff --git a/x/net/http/util.go b/x/net/http/util.go index f2efb70..bfd9fc3 100644 --- a/x/net/http/util.go +++ b/x/net/http/util.go @@ -3,6 +3,11 @@ package http import ( "strings" "unicode" + "unicode/utf8" + + "golang.org/x/net/idna" + + "github.com/goplus/llgoexamples/x/net" ) /** @@ -206,6 +211,99 @@ func headerValueContainsToken(v string, token string) bool { // httpguts.headerV return tokenEqual(trimOWS(v), token) } +// PunycodeHostPort returns the IDNA Punycode version +// of the provided "host" or "host:port" string. +func PunycodeHostPort(v string) (string, error) { // httpguts.PunycodeHostPort + if isASCII(v) { + return v, nil + } + + host, port, err := net.SplitHostPort(v) + if err != nil { + // The input 'v' argument was just a "host" argument, + // without a port. This error should not be returned + // to the caller. + host = v + port = "" + } + host, err = idna.ToASCII(host) + if err != nil { + // Non-UTF-8? Not representable in Punycode, in any + // case. + return "", err + } + if port == "" { + return host, nil + } + return net.JoinHostPort(host, port), nil +} + +func isASCII(s string) bool { // httpguts.isASCII + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + return false + } + } + return true +} + +// ValidHostHeader reports whether h is a valid host header. +func ValidHostHeader(h string) bool { // httpguts.ValidHostHeader + // The latest spec is actually this: + // + // http://tools.ietf.org/html/rfc7230#section-5.4 + // Host = uri-host [ ":" port ] + // + // Where uri-host is: + // http://tools.ietf.org/html/rfc3986#section-3.2.2 + // + // But we're going to be much more lenient for now and just + // search for any byte that's not a valid byte in any of those + // expressions. + for i := 0; i < len(h); i++ { + if !validHostByte[h[i]] { + return false + } + } + return true +} + +// See the validHostHeader comment. +var validHostByte = [256]bool{ // httpguts.validHostByte + '0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true, + '8': true, '9': true, + + 'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true, + 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true, + 'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true, + 'y': true, 'z': true, + + 'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true, + 'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true, + 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true, 'X': true, + 'Y': true, 'Z': true, + + '!': true, // sub-delims + '$': true, // sub-delims + '%': true, // pct-encoded (and used in IPv6 zones) + '&': true, // sub-delims + '(': true, // sub-delims + ')': true, // sub-delims + '*': true, // sub-delims + '+': true, // sub-delims + ',': true, // sub-delims + '-': true, // unreserved + '.': true, // unreserved + ':': true, // IPv6address + Host expression's optional port + ';': true, // sub-delims + '=': true, // sub-delims + '[': true, + '\'': true, // sub-delims + ']': true, + '_': true, // unreserved + '~': true, // unreserved +} + // IsPrint returns whether s is ASCII and printable according to // https://tools.ietf.org/html/rfc20#section-4.2. func IsPrint(s string) bool { // ascii.IsPrint diff --git a/x/net/ipsock.go b/x/net/ipsock.go index 855a864..55e1b45 100644 --- a/x/net/ipsock.go +++ b/x/net/ipsock.go @@ -1,5 +1,11 @@ package net +import ( + "unsafe" + + "github.com/goplus/llgo/c" +) + // JoinHostPort combines host and port into a network address of the // form "host:port". If host contains a colon, as found in literal // IPv6 addresses, then JoinHostPort returns "[host]:port". @@ -8,17 +14,82 @@ package net func JoinHostPort(host, port string) string { // We assume that host is a literal IPv6 address if host has // colons. + if IndexByteString(host, ':') >= 0 { return "[" + host + "]:" + port } return host + ":" + port } -func IndexByteString(s string, c byte) int { - for i := 0; i < len(s); i++ { - if s[i] == c { - return i +// SplitHostPort splits a network address of the form "host:port", +// "host%zone:port", "[host]:port" or "[host%zone]:port" into host or +// host%zone and port. +// +// A literal IPv6 address in hostport must be enclosed in square +// brackets, as in "[::1]:80", "[::1%lo0]:80". +// +// See func Dial for a description of the hostport parameter, and host +// and port results. +func SplitHostPort(hostport string) (host, port string, err error) { + const ( + missingPort = "missing port in address" + tooManyColons = "too many colons in address" + ) + addrErr := func(addr, why string) (host, port string, err error) { + return "", "", &AddrError{Err: why, Addr: addr} + } + j, k := 0, 0 + + // The port starts after the last colon. + i := last(hostport, ':') + if i < 0 { + return addrErr(hostport, missingPort) + } + + if hostport[0] == '[' { + // Expect the first ']' just before the last ':'. + end := IndexByteString(hostport, ']') + if end < 0 { + return addrErr(hostport, "missing ']' in address") } + switch end + 1 { + case len(hostport): + // There can't be a ':' behind the ']' now. + return addrErr(hostport, missingPort) + case i: + // The expected result. + default: + // Either ']' isn't followed by a colon, or it is + // followed by a colon that is not the last one. + if hostport[end+1] == ':' { + return addrErr(hostport, tooManyColons) + } + return addrErr(hostport, missingPort) + } + host = hostport[1:end] + j, k = 1, end+1 // there can't be a '[' resp. ']' before these positions + } else { + host = hostport[:i] + if IndexByteString(host, ':') >= 0 { + return addrErr(hostport, tooManyColons) + } + } + if IndexByteString(hostport[j:], '[') >= 0 { + return addrErr(hostport, "unexpected '[' in address") + } + if IndexByteString(hostport[k:], ']') >= 0 { + return addrErr(hostport, "unexpected ']' in address") + } + + port = hostport[i+1:] + return host, port, nil +} + +func IndexByteString(s string, ch byte) int { // bytealg.IndexByteString + ptr := unsafe.Pointer(unsafe.StringData(s)) + ret := c.Memchr(ptr, c.Int(ch), uintptr(len(s))) + if ret != nil { + return int(uintptr(ret) - uintptr(ptr)) } return -1 } diff --git a/x/net/net.go b/x/net/net.go new file mode 100644 index 0000000..3267d90 --- /dev/null +++ b/x/net/net.go @@ -0,0 +1,20 @@ +package net + +type AddrError struct { + Err string + Addr string +} + +func (e *AddrError) Error() string { + if e == nil { + return "" + } + s := e.Err + if e.Addr != "" { + s = "address " + e.Addr + ": " + s + } + return s +} + +func (e *AddrError) Timeout() bool { return false } +func (e *AddrError) Temporary() bool { return false } diff --git a/x/net/parse.go b/x/net/parse.go new file mode 100644 index 0000000..c110fcf --- /dev/null +++ b/x/net/parse.go @@ -0,0 +1,12 @@ +package net + +// Index of rightmost occurrence of b in s. +func last(s string, b byte) int { + i := len(s) + for i--; i >= 0; i-- { + if s[i] == b { + break + } + } + return i +} From cebff180c6ce1c390da2bb17ed0ad3372cfbf18d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Mon, 26 Aug 2024 17:46:17 +0800 Subject: [PATCH 17/21] WIP(x/http/client): Categorize and write req.headers & perform unwrapBody operation on req.Reader --- x/net/http/_demo/get/get.go | 1 - x/net/http/_demo/headers/headers.go | 3 +- x/net/http/_demo/upload/upload.go | 5 +- x/net/http/header.go | 91 +++++++++++++++ x/net/http/request.go | 172 ++++++++++------------------ x/net/http/server.go | 7 ++ x/net/http/transfer.go | 110 ++++++++++++++++++ x/net/http/transport.go | 59 +++++----- 8 files changed, 302 insertions(+), 146 deletions(-) diff --git a/x/net/http/_demo/get/get.go b/x/net/http/_demo/get/get.go index 79c18ba..6bc5b06 100644 --- a/x/net/http/_demo/get/get.go +++ b/x/net/http/_demo/get/get.go @@ -14,7 +14,6 @@ func main() { return } fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) - fmt.Println(resp.Proto) resp.PrintHeaders() body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/x/net/http/_demo/headers/headers.go b/x/net/http/_demo/headers/headers.go index 71d42b7..aa2e5d6 100644 --- a/x/net/http/_demo/headers/headers.go +++ b/x/net/http/_demo/headers/headers.go @@ -16,7 +16,7 @@ func main() { } //req.Header.Set("accept", "*/*") - req.Header.Set("accept-encoding", "identity") + req.Header.Set("accept-encoding", "gzip") //req.Header.Set("cache-control", "no-cache") //req.Header.Set("pragma", "no-cache") //req.Header.Set("priority", "u=0, i") @@ -36,6 +36,7 @@ func main() { println(err.Error()) return } + fmt.Println(resp.Status) resp.PrintHeaders() body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/x/net/http/_demo/upload/upload.go b/x/net/http/_demo/upload/upload.go index 86b57e9..fe7256b 100644 --- a/x/net/http/_demo/upload/upload.go +++ b/x/net/http/_demo/upload/upload.go @@ -10,7 +10,9 @@ import ( func main() { url := "http://httpbin.org/post" + //url := "http://localhost:8080" filePath := "/Users/spongehah/go/src/llgo/x/net/http/_demo/upload/example.txt" // Replace with your file path + //filePath := "/Users/spongehah/Downloads/xiaoshuo.txt" // Replace with your file path file, err := os.Open(filePath) if err != nil { @@ -33,7 +35,8 @@ func main() { return } defer resp.Body.Close() - + fmt.Println("Status:", resp.Status) + resp.PrintHeaders() respBody, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/header.go b/x/net/http/header.go index 6515c48..0d1e2cc 100644 --- a/x/net/http/header.go +++ b/x/net/http/header.go @@ -3,6 +3,9 @@ package http import ( "fmt" "net/textproto" + "sort" + "strings" + "sync" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" @@ -108,6 +111,94 @@ func (h Header) Clone() Header { return h2 } +var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") + +type keyValues struct { + key string + values []string +} + +// A headerSorter implements sort.Interface by sorting a []keyValues +// by key. It's used as a pointer, so it can fit in a sort.Interface +// interface value without allocation. +type headerSorter struct { + kvs []keyValues +} + +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } + +var headerSorterPool = sync.Pool{ + New: func() any { return new(headerSorter) }, +} + +// sortedKeyValues returns h's keys sorted in the returned kvs +// slice. The headerSorter used to sort is also returned, for possible +// return to headerSorterCache. +func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { + hs = headerSorterPool.Get().(*headerSorter) + if cap(hs.kvs) < len(h) { + hs.kvs = make([]keyValues, 0, len(h)) + } + kvs = hs.kvs[:0] + for k, vv := range h { + if !exclude[k] { + kvs = append(kvs, keyValues{k, vv}) + } + } + hs.kvs = kvs + sort.Sort(hs) + return kvs, hs +} + +// Write writes a header in wire format. +func (h Header) Write(reqHeaders *hyper.Headers) error { + return h.write(reqHeaders) +} + +func (h Header) write(reqHeaders *hyper.Headers) error { + return h.writeSubset(reqHeaders, nil) +} + +// WriteSubset writes a header in wire format. +// If exclude is not nil, keys where exclude[key] == true are not written. +// Keys are not canonicalized before checking the exclude map. +func (h Header) WriteSubset(reqHeaders *hyper.Headers, exclude map[string]bool) error { + return h.writeSubset(reqHeaders, exclude) +} + +func (h Header) writeSubset(reqHeaders *hyper.Headers, exclude map[string]bool) error { + kvs, sorter := h.sortedKeyValues(exclude) + for _, kv := range kvs { + if !ValidHeaderFieldName(kv.key) { + // This could be an error. In the common case of + // writing response headers, however, we have no good + // way to provide the error back to the server + // handler, so just drop invalid headers instead. + continue + } + for _, v := range kv.values { + v = headerNewlineToSpace.Replace(v) + v = textproto.TrimString(v) + if reqHeaders.Add(&[]byte(kv.key)[0], c.Strlen(c.AllocaCStr(kv.key)), &[]byte(v)[0], c.Strlen(c.AllocaCStr(v))) != hyper.OK { + headerSorterPool.Put(sorter) + return fmt.Errorf("error adding header %s: %s\n", kv.key, v) + } + //if trace != nil && trace.WroteHeaderField != nil { + // formattedVals = append(formattedVals, v) + //} + } + //if trace != nil && trace.WroteHeaderField != nil { + // trace.WroteHeaderField(kv.key, formattedVals) + // formattedVals = nil + //} + } + + headerSorterPool.Put(sorter) + return nil +} + // hasToken reports whether token appears with v, ASCII // case-insensitive, with space or comma boundaries. // token must be all lowercase. diff --git a/x/net/http/request.go b/x/net/http/request.go index be81f0e..6d74296 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -14,7 +14,6 @@ import ( "golang.org/x/net/idna" "github.com/goplus/llgo/c" - "github.com/goplus/llgo/c/os" "github.com/goplus/llgoexamples/rust/hyper" ) @@ -167,54 +166,6 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R return req, nil } -//func setPostDataNoCopy(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { -// req := (*postReq)(userdata) -// buf := req.hyperBuf.Bytes() -// len := req.hyperBuf.Len() -// n, err := req.req.Body.Read(unsafe.Slice(buf, len)) -// if err != nil { -// if err == io.EOF { -// *chunk = nil -// return hyper.PollReady -// } -// fmt.Println("error reading upload file: ", err) -// return hyper.PollError -// } -// if n > 0 { -// *chunk = req.hyperBuf -// return hyper.PollReady -// } -// if n == 0 { -// *chunk = nil -// return hyper.PollReady -// } -// -// fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) -// return hyper.PollError -//} - -// setHeaders sets the headers of the request -func (r *Request) setHeaders(hyperReq *hyper.Request) error { - headers := hyperReq.Headers() - for key, values := range r.Header { - valueLen := len(values) - if valueLen > 1 { - for _, value := range values { - if headers.Add(&[]byte(key)[0], c.Strlen(c.AllocaCStr(key)), &[]byte(value)[0], c.Strlen(c.AllocaCStr(value))) != hyper.OK { - return fmt.Errorf("error adding header %s: %s\n", key, value) - } - } - } else if valueLen == 1 { - if headers.Set(&[]byte(key)[0], c.Strlen(c.AllocaCStr(key)), &[]byte(values[0])[0], c.Strlen(c.AllocaCStr(values[0]))) != hyper.OK { - return fmt.Errorf("error setting header %s: %s\n", key, values[0]) - } - } else { - return fmt.Errorf("error setting header %s: empty value\n", key) - } - } - return nil -} - func (r *Request) expectsContinue() bool { return hasToken(r.Header.get("Expect"), "100-continue") } @@ -289,6 +240,21 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { // the Request. var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set") +// NOTE: This is not intended to reflect the actual Go version being used. +// It was changed at the time of Go 1.1 release because the former User-Agent +// had ended up blocked by some intrusion detection systems. +// See https://codereview.appspot.com/7532043. +const defaultUserAgent = "Go-http-client/1.1" + +// Headers that Request.Write handles itself and should be skipped. +var reqWriteExcludeHeader = map[string]bool{ + "Host": true, // not in Header map anyway + "User-Agent": true, + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + // extraHeaders may be nil // waitForContinue may be nil // always closes body @@ -302,16 +268,6 @@ func (r *Request) write(usingProxy bool, extraHeader Header, client *hyper.Clien // }() //} - //closed := false - //defer func() { - // if closed { - // return - // } - // if closeErr := r.closeBody(); closeErr != nil && err == nil { - // err = closeErr - // } - //}() - // Prepare the hyper.Request hyperReq, err := r.newHyperRequest(usingProxy, extraHeader) if err != nil { @@ -387,9 +343,6 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.R return nil, errors.New("net/http: can't write control character in Request.URL") } - - - // Prepare the hyper request hyperReq := hyper.NewRequest() @@ -409,29 +362,55 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.R if reqHeaders.Set(&[]byte("Host")[0], c.Strlen(c.Str("Host")), &[]byte(host)[0], c.Strlen(c.AllocaCStr(host))) != hyper.OK { return nil, fmt.Errorf("error setting header: Host: %s\n", host) } - err = r.setHeaders(hyperReq) + + // Use the defaultUserAgent unless the Header contains one, which + // may be blank to not send the header. + userAgent := defaultUserAgent + if r.Header.has("User-Agent") { + userAgent = r.Header.Get("User-Agent") + } + if userAgent != "" { + if reqHeaders.Set(&[]byte("User-Agent")[0], c.Strlen(c.Str("User-Agent")), &[]byte(userAgent)[0], c.Strlen(c.AllocaCStr(userAgent))) != hyper.OK { + return nil, fmt.Errorf("error setting header: User-Agent: %s\n", userAgent) + } + } + + // Process Body,ContentLength,Close,Trailer + //tw, err := newTransferWriter(r) + //if err != nil { + // return err + //} + //err = tw.writeHeader(w, trace) + err = r.writeHeader(reqHeaders) if err != nil { return nil, err } - if r.Body != nil { - // 100-continue - if r.ProtoAtLeast(1, 1) && r.Body != nil && r.expectsContinue() { - hyperReq.OnInformational(printInformational, nil) - } + err = r.Header.writeSubset(reqHeaders, reqWriteExcludeHeader) + if err != nil { + return nil, err + } - hyperReqBody := hyper.NewBody() - //buf := make([]byte, 2) - //hyperBuf := hyper.CopyBuf(&buf[0], uintptr(2)) - reqData := &postReq{ - req: r, - buf: make([]byte, defaultChunkSize), - //hyperBuf: hyperBuf, + if extraHeader != nil { + err = extraHeader.write(reqHeaders) + if err != nil { + return nil, err } - hyperReqBody.SetUserdata(c.Pointer(reqData)) - hyperReqBody.SetDataFunc(setPostData) - //hyperReqBody.SetDataFunc(setPostDataNoCopy) - hyperReq.SetBody(hyperReqBody) + } + + //if trace != nil && trace.WroteHeaders != nil { + // trace.WroteHeaders() + //} + + // Wait for 100-continue if expected. + if r.ProtoAtLeast(1, 1) && r.Body != nil && r.expectsContinue() { + hyperReq.OnInformational(printInformational, nil) + } + + // Write body and trailer + err = r.writeBody(hyperReq) + if err != nil { + return nil, err } return hyperReq, nil @@ -442,41 +421,6 @@ func printInformational(userdata c.Pointer, resp *hyper.Response) { fmt.Println("Informational (1xx): ", status) } -type postReq struct { - req *Request - buf []byte - //hyperBuf *hyper.Buf -} - -func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { - req := (*postReq)(userdata) - n, err := req.req.Body.Read(req.buf) - if err != nil { - if err == io.EOF { - println("EOF") - *chunk = nil - req.req.Body.Close() - return hyper.PollReady - } - fmt.Println("error reading request body: ", err) - return hyper.PollError - } - if n > 0 { - *chunk = hyper.CopyBuf(&req.buf[0], uintptr(n)) - return hyper.PollReady - } - if n == 0 { - println("n == 0") - *chunk = nil - req.req.Body.Close() - return hyper.PollReady - } - - req.req.Body.Close() - fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) - return hyper.PollError -} - func validMethod(method string) bool { /* Method = "OPTIONS" ; Section 9.2 diff --git a/x/net/http/server.go b/x/net/http/server.go index f38cbd0..5c4c58d 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -10,3 +10,10 @@ package http // size is anyway. (if we have the bytes on the machine, we might as // well read them) const maxPostHandlerReadBytes = 256 << 10 + +type readResult struct { + _ incomparable + n int + err error + b byte // byte read, if n == 1 +} diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index cf96f84..b0a52fb 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -5,10 +5,15 @@ import ( "fmt" "io" "net/textproto" + "reflect" "strconv" "strings" "sync" "unicode/utf8" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgo/c/os" + "github.com/goplus/llgoexamples/rust/hyper" ) type transferReader struct { @@ -610,3 +615,108 @@ func lowerASCII(b byte) byte { // isOWS reports whether b is an optional whitespace byte, as defined // by RFC 7230 section 3.2.3. func isOWS(b byte) bool { return b == ' ' || b == '\t' } + +// writeHeader Write Content-Length and/or Transfer-Encoding and/or Trailer header +func (r *Request) writeHeader(reqHeaders *hyper.Headers) error { + if r.Close && !hasToken(r.Header.get("Connection"), "close") { + if reqHeaders.Set(&[]byte("Connection")[0], c.Strlen(c.Str("Connection")), &[]byte("close")[0], c.Strlen(c.Str("close"))) != hyper.OK { + return fmt.Errorf("error setting header: Connection: %s\n", "close") + } + } + + // 'Content-Length' and 'Transfer-Encoding:chunked' are already handled by hyper + + // Write Trailer header + // TODO(spongehah) Trailer(writeHeader) + + return nil +} + +var nopCloserType = reflect.TypeOf(io.NopCloser(nil)) +var nopCloserWriterToType = reflect.TypeOf(io.NopCloser(struct { + io.Reader + io.WriterTo +}{})) + +// unwrapNopCloser return the underlying reader and true if r is a NopCloser +// else it return false. +func unwrapNopCloser(r io.Reader) (underlyingReader io.Reader, isNopCloser bool) { + switch reflect.TypeOf(r) { + case nopCloserType, nopCloserWriterToType: + return reflect.ValueOf(r).Field(0).Interface().(io.Reader), true + default: + return nil, false + } +} + +// unwrapBody unwraps the body's inner reader if it's a +// nopCloser. This is to ensure that body writes sourced from local +// files (*os.File types) are properly optimized. +// +// This function is only intended for use in writeBody. +func (req *Request) unwrapBody() io.Reader { + if r, ok := unwrapNopCloser(req.Body); ok { + return r + } + if r, ok := req.Body.(*readTrackingBody); ok { + r.didRead = true + return r.ReadCloser + } + return req.Body +} + +func (r *Request) writeBody(hyperReq *hyper.Request) error { + if r.Body != nil { + var body = r.unwrapBody() + hyperReqBody := hyper.NewBody() + buf := make([]byte, defaultChunkSize) + //hyperBuf := hyper.CopyBuf(&buf[0], uintptr(defaultChunkSize)) + reqData := &bodyReq{ + body: body, + buf: buf, + //hyperBuf: hyperBuf, + closeBody: r.closeBody, + } + hyperReqBody.SetUserdata(c.Pointer(reqData)) + hyperReqBody.SetDataFunc(setPostData) + hyperReq.SetBody(hyperReqBody) + } + return nil +} + +type bodyReq struct { + body io.Reader + buf []byte + //hyperBuf *hyper.Buf + closeBody func() error +} + +func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { + req := (*bodyReq)(userdata) + n, err := req.body.Read(req.buf) + //buf := req.hyperBuf.Bytes() + //bufLen := req.hyperBuf.Len() + //n, err := req.body.Read(unsafe.Slice(buf, bufLen)) + if err != nil { + if err == io.EOF { + *chunk = nil + req.closeBody() + return hyper.PollReady + } + fmt.Println("error reading request body: ", err) + return hyper.PollError + } + if n > 0 { + *chunk = hyper.CopyBuf(&req.buf[0], uintptr(n)) + //*chunk = req.hyperBuf + return hyper.PollReady + } + if n == 0 { + *chunk = nil + req.closeBody() + return hyper.PollReady + } + req.closeBody() + fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + return hyper.PollError +} diff --git a/x/net/http/transport.go b/x/net/http/transport.go index fe3efc3..5d7d48a 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -16,8 +16,8 @@ import ( "github.com/goplus/llgo/c/libuv" cnet "github.com/goplus/llgo/c/net" "github.com/goplus/llgo/c/syscall" - "github.com/goplus/llgoexamples/x/net" "github.com/goplus/llgoexamples/rust/hyper" + "github.com/goplus/llgoexamples/x/net" ) // DefaultTransport is the default implementation of Transport and is @@ -733,7 +733,6 @@ func (t *Transport) dial(ctx context.Context, cm connectMethod) (*connData, erro conn.TimeoutTimer.Start(onTimeout, uint64(treq.timeout.Milliseconds()), 0) } - libuv.InitTcp(loop, &conn.TcpHandle) libuv.InitTcp(loop, &conn.TcpHandle) (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) @@ -781,25 +780,26 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // uncompress the gzip stream if we were the layer that // requested it. requestedGzip := false - if !pc.t.DisableCompression && - req.Header.Get("Accept-Encoding") == "" && - req.Header.Get("Range") == "" && - req.Method != "HEAD" { - // Request gzip only, not deflate. Deflate is ambiguous and - // not as universally supported anyway. - // See: https://zlib.net/zlib_faq.html#faq39 - // - // Note that we don't request this for HEAD requests, - // due to a bug in nginx: - // https://trac.nginx.org/nginx/ticket/358 - // https://golang.org/issue/5522 - // - // We don't request gzip if the request is for a range, since - // auto-decoding a portion of a gzipped document will just fail - // anyway. See https://golang.org/issue/8923 - requestedGzip = true - req.extraHeaders().Set("Accept-Encoding", "gzip") - } + // TODO(spongehah) gzip(pc.roundTrip) + //if !pc.t.DisableCompression && + // req.Header.Get("Accept-Encoding") == "" && + // req.Header.Get("Range") == "" && + // req.Method != "HEAD" { + // // Request gzip only, not deflate. Deflate is ambiguous and + // // not as universally supported anyway. + // // See: https://zlib.net/zlib_faq.html#faq39 + // // + // // Note that we don't request this for HEAD requests, + // // due to a bug in nginx: + // // https://trac.nginx.org/nginx/ticket/358 + // // https://golang.org/issue/5522 + // // + // // We don't request gzip if the request is for a range, since + // // auto-decoding a portion of a gzipped document will just fail + // // anyway. See https://golang.org/issue/8923 + // requestedGzip = true + // req.extraHeaders().Set("Accept-Encoding", "gzip") + //} // The 100-continue operation in Hyper is handled in the newHyperRequest function. @@ -1126,14 +1126,15 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } resp.Body = body - // TODO(spongehah) gzip fail(readWriteLoop) - if rc.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { - resp.Body = &gzipReader{body: body} - resp.Header.Del("Content-Encoding") - resp.Header.Del("Content-Length") - resp.ContentLength = -1 - resp.Uncompressed = true - } + // TODO(spongehah) gzip(pc.readWriteLoop) + //if rc.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { + // println("gzip reader") + // resp.Body = &gzipReader{body: body} + // resp.Header.Del("Content-Encoding") + // resp.Header.Del("Content-Length") + // resp.ContentLength = -1 + // resp.Uncompressed = true + //} rw.waitForBodyRead = waitForBodyRead rw.rc = rc From c4d73157830519230bb04a52934fab473b241094 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Thu, 29 Aug 2024 11:16:42 +0800 Subject: [PATCH 18/21] WIP(x/net/http/client): Extract outwards libuv.Loop and timeout logic --- go.mod | 2 +- go.sum | 4 +- x/net/http/_demo/timeout/timeout.go | 4 +- x/net/http/client.go | 50 +- x/net/http/request.go | 17 +- x/net/http/response.go | 1 + x/net/http/transfer.go | 15 +- x/net/http/transport.go | 989 ++++++++++++++++------------ x/net/http/util.go | 2 +- 9 files changed, 639 insertions(+), 445 deletions(-) diff --git a/go.mod b/go.mod index f961f75..e893515 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/goplus/llgoexamples go 1.20 require ( - github.com/goplus/llgo v0.9.7-0.20240816085229-53d2d080f4c4 + github.com/goplus/llgo v0.9.7-0.20240830010153-2434fd778f0b golang.org/x/net v0.28.0 ) diff --git a/go.sum b/go.sum index 4c64063..5d7faad 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/goplus/llgo v0.9.7-0.20240816085229-53d2d080f4c4 h1:fqqbWhWaoseSplLJF8OTkNGl4Kruqm1wQWT/Yooq6E4= -github.com/goplus/llgo v0.9.7-0.20240816085229-53d2d080f4c4/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= +github.com/goplus/llgo v0.9.7-0.20240830010153-2434fd778f0b h1:iC0vVA8F2DNJ9wVyHI9fP9U0nM+si3LSQJ1TtGftXyo= +github.com/goplus/llgo v0.9.7-0.20240830010153-2434fd778f0b/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= diff --git a/x/net/http/_demo/timeout/timeout.go b/x/net/http/_demo/timeout/timeout.go index ddb2d25..62b2c9d 100644 --- a/x/net/http/_demo/timeout/timeout.go +++ b/x/net/http/_demo/timeout/timeout.go @@ -10,8 +10,8 @@ import ( func main() { client := &http.Client{ - Timeout: time.Millisecond, // Set a small timeout to ensure it will time out - //Timeout: time.Second * 5, + //Timeout: time.Millisecond, // Set a small timeout to ensure it will time out + Timeout: time.Second * 5, } req, err := http.NewRequest("GET", "https://www.baidu.com", nil) if err != nil { diff --git a/x/net/http/client.go b/x/net/http/client.go index bf1bfd4..d56f5f2 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -8,6 +8,7 @@ import ( "io" "log" "net/url" + "reflect" "sort" "strings" "sync" @@ -158,6 +159,9 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { Host: host, Cancel: ireq.Cancel, ctx: ireq.ctx, + + timer: ireq.timer, + timeoutch: ireq.timeoutch, } if includeBody && ireq.GetBody != nil { req.Body, err = ireq.GetBody() @@ -305,11 +309,14 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d } // TODO(spongehah) timeout(send) - req.timeoutch = make(chan struct{}, 1) + req.deadline = deadline + if deadline.IsZero() { + didTimeout = alwaysFalse + } else { + didTimeout = func() bool { return req.timer.GetDueIn() == 0 } + } //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) - sub := deadline.Sub(time.Now()) - req.timeout = sub resp, err = rt.RoundTrip(req) if err != nil { //stopTimer() @@ -469,6 +476,34 @@ func (b *cancelTimerBody) Close() error { return err } +// knownRoundTripperImpl reports whether rt is a RoundTripper that's +// maintained by the Go team and known to implement the latest +// optional semantics (notably contexts). The Request is used +// to check whether this particular request is using an alternate protocol, +// in which case we need to check the RoundTripper for that protocol. +func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { + switch t := rt.(type) { + case *Transport: + if altRT := t.alternateRoundTripper(req); altRT != nil { + return knownRoundTripperImpl(altRT, req) + } + return true + // TODO(spongehah) + //case *http2Transport, http2noDialH2RoundTripper: + // return true + } + // There's a very minor chance of a false positive with this. + // Instead of detecting our golang.org/x/net/http2.Transport, + // it might detect a Transport type in a different http2 + // package. But I know of none, and the only problem would be + // some temporarily leaked goroutines if the transport didn't + // support contexts. So this is a good enough heuristic: + if reflect.TypeOf(rt).String() == "*http2.Transport" { + return true + } + return false +} + // setRequestCancel sets req.Cancel and adds a deadline context to req // if deadline is non-zero. The RoundTripper's type is used to // determine whether the legacy CancelRequest behavior should be used. @@ -482,11 +517,10 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi if deadline.IsZero() { return nop, alwaysFalse } - //knownTransport := knownRoundTripperImpl(rt, req) + knownTransport := knownRoundTripperImpl(rt, req) oldCtx := req.Context() - //if req.Cancel == nil && knownTransport { - if req.Cancel == nil { + if req.Cancel == nil && knownTransport { // If they already had a Request.Context that's // expiring sooner, do nothing: if !timeBeforeContextDeadline(deadline, oldCtx) { @@ -504,7 +538,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) } - cancel := make(chan struct{}, 1) + cancel := make(chan struct{}) req.Cancel = cancel doCancel := func() { @@ -518,7 +552,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi } } - stopTimerCh := make(chan struct{}, 1) + stopTimerCh := make(chan struct{}) var once sync.Once stopTimer = func() { once.Do(func() { diff --git a/x/net/http/request.go b/x/net/http/request.go index 6d74296..658b033 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/goplus/llgo/c/libuv" "golang.org/x/net/idna" "github.com/goplus/llgo/c" @@ -37,12 +38,14 @@ type Request struct { RemoteAddr string RequestURI string //TLS *tls.ConnectionState - Cancel <-chan struct{} - timeoutch chan struct{} //optional + Cancel <-chan struct{} Response *Response - timeout time.Duration ctx context.Context + + deadline time.Time + timeoutch chan struct{} //tmp timeout + timer *libuv.Timer } const defaultChunkSize = 8192 @@ -117,6 +120,7 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R Header: make(Header), Body: rc, Host: u.Host, + timer: nil, } if body != nil { switch v := body.(type) { @@ -258,7 +262,7 @@ var reqWriteExcludeHeader = map[string]bool{ // extraHeaders may be nil // waitForContinue may be nil // always closes body -func (r *Request) write(usingProxy bool, extraHeader Header, client *hyper.ClientConn, exec *hyper.Executor) (err error) { +func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hyper.Executor) (err error) { //trace := httptrace.ContextClientTrace(r.Context()) //if trace != nil && trace.WroteRequest != nil { // defer func() { @@ -269,13 +273,14 @@ func (r *Request) write(usingProxy bool, extraHeader Header, client *hyper.Clien //} // Prepare the hyper.Request - hyperReq, err := r.newHyperRequest(usingProxy, extraHeader) + hyperReq, err := r.newHyperRequest(taskData.pc.isProxy, taskData.req.extra) if err != nil { return err } // Send it! sendTask := client.Send(hyperReq) - setTaskId(sendTask, read) + taskData.taskId = read + sendTask.SetUserdata(c.Pointer(taskData)) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { err = errors.New("failed to send the request") diff --git a/x/net/http/response.go b/x/net/http/response.go index 8151ac2..6ff5b3d 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -73,6 +73,7 @@ func appendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { _, err := writer.Write(bytes) if err != nil { fmt.Println("Error writing to response body:", err) + writer.Close() return hyper.IterBreak } return hyper.IterContinue diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index b0a52fb..103200c 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -670,11 +670,9 @@ func (r *Request) writeBody(hyperReq *hyper.Request) error { var body = r.unwrapBody() hyperReqBody := hyper.NewBody() buf := make([]byte, defaultChunkSize) - //hyperBuf := hyper.CopyBuf(&buf[0], uintptr(defaultChunkSize)) reqData := &bodyReq{ - body: body, - buf: buf, - //hyperBuf: hyperBuf, + body: body, + buf: buf, closeBody: r.closeBody, } hyperReqBody.SetUserdata(c.Pointer(reqData)) @@ -685,18 +683,14 @@ func (r *Request) writeBody(hyperReq *hyper.Request) error { } type bodyReq struct { - body io.Reader - buf []byte - //hyperBuf *hyper.Buf + body io.Reader + buf []byte closeBody func() error } func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { req := (*bodyReq)(userdata) n, err := req.body.Read(req.buf) - //buf := req.hyperBuf.Bytes() - //bufLen := req.hyperBuf.Len() - //n, err := req.body.Read(unsafe.Slice(buf, bufLen)) if err != nil { if err == io.EOF { *chunk = nil @@ -708,7 +702,6 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In } if n > 0 { *chunk = hyper.CopyBuf(&req.buf[0], uintptr(n)) - //*chunk = req.hyperBuf return hyper.PollReady } if n == 0 { diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 5d7d48a..062cbe9 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -10,14 +10,15 @@ import ( "net/url" "sync" "sync/atomic" + "time" "unsafe" "github.com/goplus/llgo/c" "github.com/goplus/llgo/c/libuv" cnet "github.com/goplus/llgo/c/net" "github.com/goplus/llgo/c/syscall" + "github.com/goplus/llgo/x/net" "github.com/goplus/llgoexamples/rust/hyper" - "github.com/goplus/llgoexamples/x/net" ) // DefaultTransport is the default implementation of Transport and is @@ -33,7 +34,7 @@ var DefaultTransport RoundTripper = &Transport{ // DefaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 -const defaultHTTPPort = "80" +const debugSwitch = true type Transport struct { altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme @@ -68,6 +69,11 @@ type Transport struct { // // Zero means no limit. MaxConnsPerHost int + + // libuv and hyper related + loopInitOnce sync.Once + loop *libuv.Loop + exec *hyper.Executor } // A cancelKey is the key of the reqCanceler map. @@ -82,29 +88,6 @@ type cancelKey struct { // any size (as long as it's first). type incomparable [0]func() -type requestAndChan struct { - _ incomparable - req *Request - cancelKey cancelKey - ch chan responseAndError // unbuffered; always send in select on callerGone - - // whether the Transport (as opposed to the user client code) - // added the Accept-Encoding gzip header. If the Transport - // set it, only then do we transparently decode the gzip. - addedGzip bool - - callerGone <-chan struct{} // closed when roundTrip caller has returned -} - -// A writeRequest is sent by the caller's goroutine to the -// writeLoop's goroutine to write a request while the read loop -// concurrently waits on both the write response and the server's -// reply. -type writeRequest struct { - req *transportRequest - ch chan<- error -} - // responseAndError is how the goroutine reading from an HTTP/1 server // communicates with the goroutine doing the RoundTrip. type responseAndError struct { @@ -113,10 +96,9 @@ type responseAndError struct { err error } -type connAndTimeoutChan struct { - _ incomparable - conn *connData +type timeoutData struct { timeoutch chan struct{} + taskData *taskData } type readTrackingBody struct { @@ -205,8 +187,6 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM if t.Proxy != nil { cm.proxyURL, err = t.Proxy(treq.Request) } - // TODO(spongehah) cm.treq(connectMethod) - cm.treq = treq cm.onlyH1 = treq.requiresHTTP1() return cm, err } @@ -293,9 +273,110 @@ func (t *Transport) cancelRequest(key cancelKey, err error) bool { return cancel != nil } +func (t *Transport) close(err error) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + t.closeLocked(err) +} + +func (t *Transport) closeLocked(err error) { + if err != nil { + fmt.Println(err) + } + if t.loop != nil { + t.loop.Close() + } + if t.exec != nil { + t.exec.Free() + } +} + +func getMilliseconds(deadline time.Time) uint64 { + microseconds := deadline.Sub(time.Now()).Microseconds() + milliseconds := microseconds / 1e3 + if microseconds%1e3 != 0 { + milliseconds += 1 + } + return uint64(milliseconds) +} + // ---------------------------------------------------------- func (t *Transport) RoundTrip(req *Request) (*Response, error) { + if debugSwitch { + println("RoundTrip start") + defer println("RoundTrip end") + } + t.loopInitOnce.Do(func() { + t.loop = libuv.LoopNew() + t.exec = hyper.NewExecutor() + + //idle := &libuv.Idle{} + //libuv.InitIdle(t.loop, idle) + //(*libuv.Handle)(c.Pointer(idle)).SetData(c.Pointer(t)) + //idle.Start(readWriteLoop) + + checker := &libuv.Check{} + libuv.InitCheck(t.loop, checker) + (*libuv.Handle)(c.Pointer(checker)).SetData(c.Pointer(t)) + checker.Start(readWriteLoop) + + go t.loop.Run(libuv.RUN_DEFAULT) + }) + + // If timeout is set, start the timer + var didTimeout func() bool + var stopTimer func() + // Only the first request will initialize the timer + if req.timer == nil && !req.deadline.IsZero() { + req.timer = &libuv.Timer{} + req.timeoutch = make(chan struct{}, 1) + libuv.InitTimer(t.loop, req.timer) + ch := &timeoutData{ + timeoutch: req.timeoutch, + taskData: nil, + } + (*libuv.Handle)(c.Pointer(req.timer)).SetData(c.Pointer(ch)) + + req.timer.Start(onTimeout, getMilliseconds(req.deadline), 0) + if debugSwitch { + println("timer start") + } + didTimeout = func() bool { return req.timer.GetDueIn() == 0 } + stopTimer = func() { + close(req.timeoutch) + req.timer.Stop() + (*libuv.Handle)(c.Pointer(req.timer)).Close(nil) + if debugSwitch { + println("timer close") + } + } + } else { + didTimeout = alwaysFalse + stopTimer = nop + } + + resp, err := t.doRoundTrip(req) + if err != nil { + stopTimer() + return nil, err + } + + if !req.deadline.IsZero() { + resp.Body = &cancelTimerBody{ + stop: stopTimer, + rc: resp.Body, + reqDidTimeout: didTimeout, + } + } + return resp, nil +} + +func (t *Transport) doRoundTrip(req *Request) (*Response, error) { + if debugSwitch { + println("doRoundTrip start") + defer println("doRoundTrip end") + } //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) //ctx := req.Context() //trace := httptrace.ContextClientTrace(ctx) @@ -354,13 +435,19 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } for { - // TODO(spongehah) timeout(t.RoundTrip): because of that ctx not initialized ( initialized in setRequestCancel() ) + // TODO(spongehah) timeout(t.doRoundTrip) //select { //case <-ctx.Done(): // req.closeBody() // return nil, ctx.Err() //default: //} + select { + case <-req.timeoutch: + req.closeBody() + return nil, errors.New("request timeout!") + default: + } // treq gets modified by roundTrip, so we need to recreate for each retry. //treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey} @@ -376,6 +463,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { // pre-CONNECTed to https server. In any case, we'll be ready // to send it requests. pconn, err := t.getConn(treq, cm) + if err != nil { t.setReqCanceler(cancelKey, nil) req.closeBody() @@ -390,19 +478,24 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } else { resp, err = pconn.roundTrip(treq) } + if err == nil { resp.Request = origReq return resp, nil } // Failed. Clean up and determine whether to retry. - // TODO(spongehah) Retry & ConnPool(t.RoundTrip) + // TODO(spongehah) Retry & ConnPool(t.doRoundTrip) return nil, err } } func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error) { - //req := treq.Request + if debugSwitch { + println("getConn start") + defer println("getConn end") + } + req := treq.Request //trace := treq.trace //ctx := req.Context() //if trace != nil && trace.GetConn != nil { @@ -413,6 +506,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi cm: cm, key: cm.key(), //ctx: ctx, + timeoutch: treq.timeoutch, ready: make(chan struct{}, 1), beforeDial: testHookPrePendingDial, afterDial: testHookPostPendingDial, @@ -458,10 +552,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // what caused w.err; if so, prefer to return the // cancellation error (see golang.org/issue/16049). select { + // TODO(spongehah) cancel(t.getConn) //case <-req.Cancel: // return nil, errRequestCanceledConn //case <-req.Context().Done(): // return nil, req.Context().Err() + case <-req.timeoutch: + return nil, errors.New("timeout: req.Context().Err()") case err := <-cancelc: if err == errRequestCanceled { err = errRequestCanceledConn @@ -475,10 +572,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // TODO(spongehah) cancel(t.getConn) //case <-req.Cancel: // return nil, errRequestCanceledConn - case <-treq.Request.timeoutch: - return nil, fmt.Errorf("request timeout\n") //case <-req.Context().Done(): - // return nil, req.Context().Err() + // return nil, + case <-req.timeoutch: + if debugSwitch { + println("getConn: timeoutch") + } + return nil, errors.New("timeout: req.Context().Err()\n") case err := <-cancelc: if err == errRequestCanceled { err = errRequestCanceledConn @@ -490,6 +590,10 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // queueForDial queues w to wait for permission to begin dialing. // Once w receives permission to dial, it will do so in a separate goroutine. func (t *Transport) queueForDial(w *wantConn) { + if debugSwitch { + println("queueForDial start") + defer println("queueForDial end") + } w.beforeDial() if t.MaxConnsPerHost <= 0 { @@ -522,9 +626,13 @@ func (t *Transport) queueForDial(w *wantConn) { // dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()]. // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. func (t *Transport) dialConnFor(w *wantConn) { + if debugSwitch { + println("dialConnFor start") + defer println("dialConnFor end") + } defer w.afterDial() - pc, err := t.dialConn(w.ctx, w.cm) + pc, err := t.dialConn(w.timeoutch, w.cm) w.tryDeliver(pc, err) // TODO(spongehah) ConnPool(t.dialConnFor) //delivered := w.tryDeliver(pc, err) @@ -593,12 +701,19 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { } } -func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { +func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn *persistConn, err error) { + if debugSwitch { + println("dialConn start") + defer println("dialConn end") + } + select { + case <-timeoutch: + return + default: + } pconn = &persistConn{ t: t, cacheKey: cm.key(), - reqch: make(chan requestAndChan, 1), - writech: make(chan writeRequest, 1), closech: make(chan struct{}, 1), writeLoopDone: make(chan struct{}, 1), } @@ -611,7 +726,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // } // return err //} - + // //if cm.scheme() == "https" && t.hasCustomTLSDialer() { // var err error // pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr()) @@ -639,27 +754,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // } //} else { //conn, err := t.dial(ctx, "tcp", cm.addr()) - conn, err := t.dial(ctx, cm) + conn, err := t.dial(timeoutch, cm.addr()) if err != nil { return nil, err } pconn.conn = conn - // hyper specific - // Hookup the IO - hyperIo := newIoWithConnReadWrite(conn) - // We need an executor generally to poll futures - exec := hyper.NewExecutor() - // Prepare client options - opts := hyper.NewClientConnOptions() - opts.Exec(exec) - pconn.exec = exec - // send the handshake - handshakeTask := hyper.Handshake(hyperIo, opts) - setTaskId(handshakeTask, write) - // Let's wait for the handshake to finish... - exec.Push(handshakeTask) - //if cm.scheme() == "https" { // var firstTLSHost string // if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { @@ -670,7 +770,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // } //} //} - + select { + case <-timeoutch: + conn.Close() + return + default: + } // TODO(spongehah) Proxy(https/sock5)(t.dialConn) // Proxy setup. switch { @@ -704,36 +809,28 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // } //} - go pconn.readWriteLoop(libuv.DefaultLoop()) - + select { + case <-timeoutch: + conn.Close() + return + default: + } return pconn, nil } -func (t *Transport) dial(ctx context.Context, cm connectMethod) (*connData, error) { - treq := cm.treq - host := treq.URL.Hostname() - port := treq.URL.Port() - if port == "" { - port = defaultHTTPPort +func (t *Transport) dial(timeoutch chan struct{}, addr string) (*connData, error) { + if debugSwitch { + println("dial start") + defer println("dial end") } - loop := libuv.DefaultLoop() - conn := new(connData) - if conn == nil { - return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err } - // If timeout is set, start the timer - if treq.timeout > 0 { - libuv.InitTimer(loop, &conn.TimeoutTimer) - ct := &connAndTimeoutChan{ - conn: conn, - timeoutch: treq.Request.timeoutch, - } - (*libuv.Handle)(c.Pointer(&conn.TimeoutTimer)).SetData(c.Pointer(ct)) - conn.TimeoutTimer.Start(onTimeout, uint64(treq.timeout.Milliseconds()), 0) - } + conn := new(connData) - libuv.InitTcp(loop, &conn.TcpHandle) + libuv.InitTcp(t.loop, &conn.TcpHandle) (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) var hints cnet.AddrInfo @@ -744,14 +841,12 @@ func (t *Transport) dial(ctx context.Context, cm connectMethod) (*connData, erro var res *cnet.AddrInfo status := cnet.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) if status != 0 { - close(treq.Request.timeoutch) return nil, fmt.Errorf("getaddrinfo error\n") } (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, onConnect) if status != 0 { - close(treq.Request.timeoutch) return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } @@ -760,6 +855,10 @@ func (t *Transport) dial(ctx context.Context, cm connectMethod) (*connData, erro } func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { + if debugSwitch { + println("roundTrip start") + defer println("roundTrip end") + } testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { // TODO(spongehah) ConnPool(pc.roundTrip) @@ -819,43 +918,61 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } }() - const debugRoundTrip = false // Debug switch provided for developers - // Write the request concurrently with waiting for a response, // in case the server decides to reply before reading our full // request body. - - // In Hyper, the writeLoop() and readLoop() are combined together --> readWriteLoop(). startBytesWritten := pc.conn.nwrite writeErrCh := make(chan error, 1) - pc.writech <- writeRequest{req: req, ch: writeErrCh} - - // Send the request to readWriteLoop(). resc := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{ - req: req.Request, - cancelKey: req.cancelKey, - ch: resc, + + // Hookup the IO + hyperIo := newIoWithConnReadWrite(pc.conn) + // We need an executor generally to poll futures + // Prepare client options + opts := hyper.NewClientConnOptions() + opts.Exec(pc.t.exec) + // send the handshake + handshakeTask := hyper.Handshake(hyperIo, opts) + taskData := &taskData{ + taskId: write, + req: req, + pc: pc, addedGzip: requestedGzip, + writeErrCh: writeErrCh, callerGone: gone, + resc: resc, } + handshakeTask.SetUserdata(c.Pointer(taskData)) + // Send the request to readWriteLoop(). + // Let's wait for the handshake to finish... + + pc.t.exec.Push(handshakeTask) + async := &libuv.Async{} + pc.t.loop.Async(async, asyncCb) + async.Send() //var respHeaderTimer <-chan time.Time //cancelChan := req.Request.Cancel //ctxDoneChan := req.Context().Done() + timeoutch := req.timeoutch pcClosed := pc.closech canceled := false for { testHookWaitResLoop() - + if debugSwitch { + println("roundTrip for") + } select { case err := <-writeErrCh: - if debugRoundTrip { - //req.logf("writeErrCh resv: %T/%#v", err, err) + if debugSwitch { + println("roundTrip: writeErrch") } if err != nil { pc.close(fmt.Errorf("write error: %w", err)) + if pc.conn.nwrite == startBytesWritten { + err = nothingWrittenError{err} + } return nil, pc.mapRoundTripError(req, startBytesWritten, err) } //if d := pc.t.ResponseHeaderTimeout; d > 0 { @@ -867,69 +984,53 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // respHeaderTimer = timer.C //} case <-pcClosed: + if debugSwitch { + println("roundTrip: pcClosed") + } pcClosed = nil if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { - if debugRoundTrip { - //req.logf("closech recv: %T %#v", pc.closed, pc.closed) - } return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) } //case <-respHeaderTimer: case re := <-resc: + if debugSwitch { + println("roundTrip: resc") + } if (re.res == nil) == (re.err == nil) { - println(1) return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) } - if debugRoundTrip { - println(2) - //req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err) - } if re.err != nil { - println(3) return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) } return re.res, nil // TODO(spongehah) cancel(pc.roundTrip) //case <-cancelChan: - case <-req.Request.timeoutch: - return nil, fmt.Errorf("request timeout\n") + // canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) + // cancelChan = nil + //case <-ctxDoneChan: + // canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err()) + // cancelChan = nil + // ctxDoneChan = nil + case <-timeoutch: + if debugSwitch { + println("roundTrip: timeoutch") + } + canceled = pc.t.cancelRequest(req.cancelKey, errors.New("timeout: req.Context().Err()")) + timeoutch = nil + return nil, errors.New("request timeout") } } } +func asyncCb(async *libuv.Async) { + println("async called") +} + // readWriteLoop handles the main I/O loop for a persistent connection. // It processes incoming requests, sends them to the server, and handles responses. -func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { - // writeLoop related - defer close(pc.writeLoopDone) - - // readLoop related - closeErr := errReadLoopExiting // default value, if not changed below - defer func() { - pc.close(closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //pc.t.removeIdleConn(pc) - }() - - //tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { - // if err := pc.t.tryPutIdleConn(pc); err != nil { - // closeErr = err - // if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { - // trace.PutIdleConn(err) - // } - // return false - // } - // if trace != nil && trace.PutIdleConn != nil { - // trace.PutIdleConn(nil) - // } - // return true - //} - - // eofc is used to block caller goroutines reading from Response.Body - // at EOF until this goroutines has (potentially) added the connection - // back to the idle pool. - eofc := make(chan struct{}, 1) - defer close(eofc) // unblock reader on errors +func readWriteLoop(idle *libuv.Check) { + println("polling") + t := (*Transport)((*libuv.Handle)(c.Pointer(idle)).GetData()) // Read this once, before loop starts. (to avoid races in tests) testHookMu.Lock() @@ -938,281 +1039,334 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { const debugReadWriteLoop = true // Debug switch provided for developers - if debugReadWriteLoop { - println("readWriteLoop start") - } - // The polling state machine! // Poll all ready tasks and act on them... - alive := true - var bodyWriter *io.PipeWriter - var rw readWaiter - for alive { - select { - case <-pc.closech: - if debugReadWriteLoop { - println("closech") - } + for { + task := t.exec.Poll() + if task == nil { return - default: - task := pc.exec.Poll() - if task == nil { - loop.Run(libuv.RUN_ONCE) - continue - } - taskId := (taskId)(uintptr(task.Userdata())) + } + taskData := (*taskData)(task.Userdata()) + var taskId taskId + if taskData != nil { + taskId = taskData.taskId + } else { + taskId = notSet + } + if debugReadWriteLoop { + println("taskId: ", taskId) + } + switch taskId { + case write: if debugReadWriteLoop { - println("taskId: ", taskId) + println("write") } - switch taskId { - case write: - if debugReadWriteLoop { - println("write") - } - wr := <-pc.writech // blocking - startBytesWritten := pc.conn.nwrite - err := checkTaskType(task, write) - client := (*hyper.ClientConn)(task.Value()) + select { + case <-taskData.pc.closech: task.Free() - if err == nil { - // TODO(spongehah) Proxy(writeLoop) - err = wr.req.Request.write(pc.isProxy, wr.req.extra, client, pc.exec) - } - // For this request, no longer need the client - client.Free() - if bre, ok := err.(requestBodyReadError); ok { - err = bre.error - // Errors reading from the user's - // Request.Body are high priority. - // Set it here before sending on the - // channels below or calling - // pc.close() which tears down - // connections and causes other - // errors. - wr.req.setError(err) - } - if err != nil { - if pc.conn.nwrite == startBytesWritten { - err = nothingWrittenError{err} - } - //pc.writeErrCh <- err // to the body reader, which might recycle us - wr.ch <- err // to the roundTrip function - pc.close(err) - return - } + continue + default: + } - if debugReadWriteLoop { - println("write end") - } - case read: - if debugReadWriteLoop { - println("read") - } + err := checkTaskType(task, write) + client := (*hyper.ClientConn)(task.Value()) + task.Free() - err := checkTaskType(task, read) + if err == nil { + // TODO(spongehah) Proxy(writeLoop) + err = taskData.req.Request.write(client, taskData, t.exec) + } + // For this request, no longer need the client + client.Free() + if bre, ok := err.(requestBodyReadError); ok { + err = bre.error + // Errors reading from the user's + // Request.Body are high priority. + // Set it here before sending on the + // channels below or calling + // pc.close() which tears down + // connections and causes other + // errors. + taskData.req.setError(err) + } + if err != nil { + //pc.writeErrCh <- err // to the body reader, which might recycle us + taskData.writeErrCh <- err // to the roundTrip function + taskData.pc.close(err) + continue + } - pc.mu.Lock() - if pc.numExpectedResponses == 0 { - pc.closeLocked(errServerClosedIdle) - pc.mu.Unlock() - return - } - pc.mu.Unlock() + if debugReadWriteLoop { + println("write end") + } + case read: + if debugReadWriteLoop { + println("read") + } - rc := <-pc.reqch // blocking - //trace := httptrace.ContextClientTrace(rc.req.Context()) + if taskData.pc.closeErr == nil { + taskData.pc.closeErr = errReadLoopExiting + } + // TODO(spongehah) ConnPool(readWriteLoop) + //if taskData.pc.tryPutIdleConn == nil { + // //taskData.pc.tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { + // // if err := pc.t.tryPutIdleConn(pc); err != nil { + // // closeErr = err + // // if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // // trace.PutIdleConn(err) + // // } + // // return false + // // } + // // if trace != nil && trace.PutIdleConn != nil { + // // trace.PutIdleConn(nil) + // // } + // // return true + // //} + //} - // Take the results - hyperResp := (*hyper.Response)(task.Value()) - task.Free() + err := checkTaskType(task, read) - var resp *Response - var respBody *hyper.Body - if err == nil { - var pr *io.PipeReader - pr, bodyWriter = io.Pipe() - resp, err = ReadResponse(pr, rc.req, hyperResp) - respBody = hyperResp.Body() - } else { - err = transportReadFromServerError{err} - closeErr = err - } + taskData.pc.mu.Lock() + if taskData.pc.numExpectedResponses == 0 { + taskData.pc.closeLocked(errServerClosedIdle) + taskData.pc.mu.Unlock() - // No longer need the response - hyperResp.Free() + // defer + taskData.pc.close(taskData.pc.closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //t.removeIdleConn(pc) + continue + } + taskData.pc.mu.Unlock() + + //trace := httptrace.ContextClientTrace(rc.req.Context()) + + // Take the results + hyperResp := (*hyper.Response)(task.Value()) + task.Free() + + var resp *Response + var respBody *hyper.Body + if err == nil { + var pr *io.PipeReader + pr, taskData.bodyWriter = io.Pipe() + resp, err = ReadResponse(pr, taskData.req.Request, hyperResp) + respBody = hyperResp.Body() + } else { + err = transportReadFromServerError{err} + taskData.pc.closeErr = err + } - if err != nil { - select { - case rc.ch <- responseAndError{err: err}: - case <-rc.callerGone: - return - } - return - } + // No longer need the response + hyperResp.Free() - // Response has been returned, stop the timer - if rc.req.timeout > 0 { - pc.conn.TimeoutTimer.Stop() - (*libuv.Handle)(c.Pointer(&pc.conn.TimeoutTimer)).Close(nil) + if err != nil { + select { + case taskData.resc <- responseAndError{err: err}: + case <-taskData.callerGone: + // defer + taskData.pc.close(taskData.pc.closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //t.removeIdleConn(pc) + continue } + // defer + taskData.pc.close(taskData.pc.closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //t.removeIdleConn(pc) + continue + } - pc.mu.Lock() - pc.numExpectedResponses-- - pc.mu.Unlock() + taskData.pc.mu.Lock() + taskData.pc.numExpectedResponses-- + taskData.pc.mu.Unlock() - bodyWritable := resp.bodyIsWritable() - hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 + bodyWritable := resp.bodyIsWritable() + hasBody := taskData.req.Method != "HEAD" && resp.ContentLength != 0 - if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable { - // Don't do keep-alive on error if either party requested a close - // or we get an unexpected informational (1xx) response. - // StatusCode 100 is already handled above. - alive = false - } + if resp.Close || taskData.req.Close || resp.StatusCode <= 199 || bodyWritable { + // Don't do keep-alive on error if either party requested a close + // or we get an unexpected informational (1xx) response. + // StatusCode 100 is already handled above. + taskData.pc.alive = false + } - if !hasBody || bodyWritable { - //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) - pc.t.replaceReqCanceler(rc.cancelKey, nil) + if !hasBody || bodyWritable { + //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) + t.replaceReqCanceler(taskData.req.cancelKey, nil) + + // TODO(spongehah) ConnPool(readWriteLoop) + //// Put the idle conn back into the pool before we send the response + //// so if they process it quickly and make another request, they'll + //// get this same conn. But we use the unbuffered channel 'rc' + //// to guarantee that persistConn.roundTrip got out of its select + //// potentially waiting for this persistConn to close. + //taskData.pc.alive = taskData.pc.alive && + // !pc.sawEOF && + // pc.wroteRequest() && + // replaced && tryPutIdleConn(trace) + + if bodyWritable { + taskData.pc.closeErr = errCallerOwnsConn + } + select { + case taskData.resc <- responseAndError{res: resp}: + case <-taskData.callerGone: + // defer + taskData.pc.close(taskData.pc.closeErr) // TODO(spongehah) ConnPool(readWriteLoop) - //// Put the idle conn back into the pool before we send the response - //// so if they process it quickly and make another request, they'll - //// get this same conn. But we use the unbuffered channel 'rc' - //// to guarantee that persistConn.roundTrip got out of its select - //// potentially waiting for this persistConn to close. - //alive = alive && - // !pc.sawEOF && - // pc.wroteRequest() && - // replaced && tryPutIdleConn(trace) - - if bodyWritable { - closeErr = errCallerOwnsConn - } - - select { - case rc.ch <- responseAndError{res: resp}: - case <-rc.callerGone: - return - } - - // Now that they've read from the unbuffered channel, they're safely - // out of the select that also waits on this goroutine to die, so - // we're allowed to exit now if needed (if alive is false) - testHookReadLoopBeforeNextRead() + //t.removeIdleConn(pc) continue } + // Now that they've read from the unbuffered channel, they're safely + // out of the select that also waits on this goroutine to die, so + // we're allowed to exit now if needed (if alive is false) + testHookReadLoopBeforeNextRead() + if taskData.pc.alive == false { + // defer + taskData.pc.close(taskData.pc.closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //t.removeIdleConn(pc) + } + continue + } - waitForBodyRead := make(chan bool, 2) - body := &bodyEOFSignal{ - body: resp.Body, - earlyCloseFn: func() error { - waitForBodyRead <- false - <-eofc // will be closed by deferred call at the end of the function - return nil - }, - fn: func(err error) error { - isEOF := err == io.EOF - waitForBodyRead <- isEOF - if isEOF { - <-eofc // see comment above eofc declaration - } else if err != nil { - if cerr := pc.canceled(); cerr != nil { - return cerr - } + body := &bodyEOFSignal{ + body: resp.Body, + earlyCloseFn: func() error { + taskData.bodyWriter.Close() + return nil + }, + fn: func(err error) error { + isEOF := err == io.EOF + if !isEOF { + if cerr := taskData.pc.canceled(); cerr != nil { + return cerr } - return err - }, - } - resp.Body = body - - // TODO(spongehah) gzip(pc.readWriteLoop) - //if rc.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { - // println("gzip reader") - // resp.Body = &gzipReader{body: body} - // resp.Header.Del("Content-Encoding") - // resp.Header.Del("Content-Length") - // resp.ContentLength = -1 - // resp.Uncompressed = true - //} - - rw.waitForBodyRead = waitForBodyRead - rw.rc = rc - bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(bodyWriter)) - setTaskId(bodyForeachTask, readDone) - pc.exec.Push(bodyForeachTask) - - // TODO(spongehah) select blocking(readWriteLoop) - //select { - //case rc.ch <- responseAndError{res: resp}: - //case <-rc.callerGone: - // return - //} - rc.ch <- responseAndError{res: resp} - - if debugReadWriteLoop { - println("read end") - } - case readDone: - // A background task of reading the response body is completed - if debugReadWriteLoop { - println("readDone") - } - if bodyWriter != nil { - bodyWriter.Close() - } - checkTaskType(task, readDone) + } + return err + }, + } + resp.Body = body + + // TODO(spongehah) gzip(pc.readWriteLoop) + //if taskData.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { + // println("gzip reader") + // resp.Body = &gzipReader{body: body} + // resp.Header.Del("Content-Encoding") + // resp.Header.Del("Content-Length") + // resp.ContentLength = -1 + // resp.Uncompressed = true + //} - hyperBodyEOF := task.Type() == hyper.TaskEmpty - // free the task - task.Free() + bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(taskData.bodyWriter)) + taskData.taskId = readDone + bodyForeachTask.SetUserdata(c.Pointer(taskData)) + t.exec.Push(bodyForeachTask) + (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData + + // TODO(spongehah) select blocking(readWriteLoop) + //select { + //case taskData.resc <- responseAndError{res: resp}: + //case <-taskData.callerGone: + // // defer + // taskData.pc.close(taskData.pc.closeErr) + // // TODO(spongehah) ConnPool(readWriteLoop) + // //t.removeIdleConn(pc) + // continue + //} + select { + case <-taskData.callerGone: + // defer + taskData.pc.close(taskData.pc.closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //t.removeIdleConn(pc) + continue + default: + } + taskData.resc <- responseAndError{res: resp} - // Before looping back to the top of this function and peeking on - // the bufio.Reader, wait for the caller goroutine to finish - // reading the response body. (or for cancellation or death) - select { - case bodyEOF := <-rw.waitForBodyRead: - bodyEOF = bodyEOF && hyperBodyEOF - //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool - pc.t.replaceReqCanceler(rw.rc.cancelKey, nil) // before pc might return to idle pool - // TODO(spongehah) ConnPool(readWriteLoop) - //alive = alive && - // bodyEOF && - // !pc.sawEOF && - // pc.wroteRequest() && - // replaced && tryPutIdleConn(trace) - - eofc <- struct{}{} - // TODO(spongehah) cancel(pc.readWriteLoop) - //case <-rw.rc.req.Cancel: - // alive = false - // pc.t.CancelRequest(rw.rc.req) - //case <-rw.rc.req.Context().Done(): - // alive = false - // pc.t.cancelRequest(rw.rc.cancelKey, rw.rc.req.Context().Err()) - case <-pc.closech: - alive = false - } + if debugReadWriteLoop { + println("read end") + } + case readDone: + // A background task of reading the response body is completed + if debugReadWriteLoop { + println("readDone") + } + if taskData.bodyWriter != nil { + taskData.bodyWriter.Close() + } + checkTaskType(task, readDone) + + //bodyEOF := task.Type() == hyper.TaskEmpty + // free the task + task.Free() + + t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool + // TODO(spongehah) ConnPool(readWriteLoop) + //taskData.pc.alive = taskData.pc.alive && + // bodyEOF && + // !pc.sawEOF && + // pc.wroteRequest() && + // replaced && tryPutIdleConn(trace) + + // TODO(spongehah) cancel(pc.readWriteLoop) + //case <-rw.rc.req.Cancel: + // taskData.pc.alive = false + // pc.t.CancelRequest(rw.rc.req) + //case <-rw.rc.req.Context().Done(): + // taskData.pc.alive = false + // pc.t.cancelRequest(rw.rc.cancelKey, rw.rc.req.Context().Err()) + //case <-taskData.pc.closech: + // taskData.pc.alive = false + //} - testHookReadLoopBeforeNextRead() - if debugReadWriteLoop { - println("readDone end") - } - case notSet: - // A background task for hyper_client completed... - task.Free() + select { + case <-taskData.req.timeoutch: + continue + case <-taskData.pc.closech: + taskData.pc.alive = false + default: + } + + if taskData.pc.alive == false { + // defer + taskData.pc.close(taskData.pc.closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //t.removeIdleConn(pc) + } + + testHookReadLoopBeforeNextRead() + if debugReadWriteLoop { + println("readDone end") } + case notSet: + // A background task for hyper_client completed... + task.Free() } } } // ---------------------------------------------------------- +type taskData struct { + taskId taskId + bodyWriter *io.PipeWriter + req *transportRequest + pc *persistConn + addedGzip bool + writeErrCh chan error + callerGone chan struct{} + resc chan responseAndError +} + type connData struct { TcpHandle libuv.Tcp ConnectReq libuv.Connect ReadBuf libuv.Buf - TimeoutTimer libuv.Timer ReadBufFilled uintptr nwrite int64 // bytes written(Replaced from persistConn's nwrite) ReadWaker *hyper.Waker @@ -1238,8 +1392,10 @@ func (conn *connData) Close() error { // onConnect is the libuv callback for a successful connection func onConnect(req *libuv.Connect, status c.Int) { - //conn := (*ConnData)(req.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(req)).data + if debugSwitch { + println("connect start") + defer println("connect end") + } conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) if status < 0 { @@ -1364,11 +1520,25 @@ func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui } // onTimeout is the libuv callback for a timeout -func onTimeout(handle *libuv.Timer) { - ct := (*connAndTimeoutChan)((*libuv.Handle)(c.Pointer(handle)).GetData()) - close(ct.timeoutch) - // Close the timer - (*libuv.Handle)(c.Pointer(&ct.conn.TimeoutTimer)).Close(nil) +func onTimeout(timer *libuv.Timer) { + if debugSwitch { + println("onTimeout start") + defer println("onTimeout end") + } + data := (*timeoutData)((*libuv.Handle)(c.Pointer(timer)).GetData()) + close(data.timeoutch) + timer.Stop() + + taskData := data.taskData + if taskData != nil { + pc := taskData.pc + pc.alive = false + pc.t.cancelRequest(taskData.req.cancelKey, errors.New("timeout: req.Context().Err()")) + // defer + pc.close(pc.closeErr) + // TODO(spongehah) ConnPool(onTimeout) + //t.removeIdleConn(pc) + } } // newIoWithConnReadWrite creates a new IO with read and write callbacks @@ -1390,17 +1560,6 @@ const ( readDone ) -type readWaiter struct { - rc requestAndChan - waitForBodyRead chan bool -} - -// setTaskId Set taskId to the task's userdata as a unique identifier -func setTaskId(task *hyper.Task, userData taskId) { - var data = userData - task.SetUserdata(unsafe.Pointer(uintptr(data))) -} - // checkTaskType checks the task type func checkTaskType(task *hyper.Task, curTaskId taskId) error { switch curTaskId { @@ -1455,14 +1614,15 @@ func fail(err *hyper.Error) error { // error values for debugging and testing, not seen by users. var ( - errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") - errConnBroken = errors.New("http: putIdleConn: connection is in bad state") - errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") - errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") - errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") - errCloseIdleConns = errors.New("http: CloseIdleConnections called") - errReadLoopExiting = errors.New("http: persistConn.readLoop exiting") - errIdleConnTimeout = errors.New("http: idle connection timeout") + errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") + errConnBroken = errors.New("http: putIdleConn: connection is in bad state") + errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") + errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") + errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") + errCloseIdleConns = errors.New("http: CloseIdleConnections called") + errReadLoopExiting = errors.New("http: Transport.readWriteLoop.read exiting") + errReadWriteLoopExiting = errors.New("http: Transport.readWriteLoop exiting") + errIdleConnTimeout = errors.New("http: idle connection timeout") // errServerClosedIdle is not seen by users for idempotent requests, but may be // seen by a user if the server shuts down an idle connection and sends its FIN @@ -1581,9 +1741,7 @@ type persistConn struct { conn *connData //tlsState *tls.ConnectionState //nwrite int64 // bytes written(Replaced by connData.nwrite) - reqch chan requestAndChan // written by roundTrip; read by readWriteLoop - writech chan writeRequest // written by roundTrip; read by readWriteLoop - closech chan struct{} // closed when conn closed + closech chan struct{} // closed when conn closed isProxy bool writeLoopDone chan struct{} // closed when readWriteLoop ends @@ -1598,8 +1756,9 @@ type persistConn struct { // original Request given to RoundTrip is not modified) mutateHeaderFunc func(Header) - // hyper specific - exec *hyper.Executor + // other + alive bool // Replace the alive in readLoop + closeErr error // Replace the closeErr in readLoop } func (pc *persistConn) cancelRequest(err error) { @@ -1635,13 +1794,10 @@ func (pc *persistConn) closeLocked(err error) { pc.conn.Close() } close(pc.closech) + close(pc.writeLoopDone) } } pc.mutateHeaderFunc = nil - // hyper related - if pc.exec != nil { - pc.exec.Free() - } } // mapRoundTripError returns the appropriate error value for @@ -1742,8 +1898,7 @@ type connectMethod struct { // then targetAddr is not included in the connect method key, because the socket can // be reused for different targetAddr values. targetAddr string - treq *transportRequest // optional - onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 + onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 } // connectMethodKey is the map key version of connectMethod, with a @@ -1808,10 +1963,11 @@ func (cm *connectMethod) proxyAuth() string { // These three options are racing against each other and use // wantConn to coordinate and agree about the winning outcome. type wantConn struct { - cm connectMethod - key connectMethodKey // cm.key() - ctx context.Context // context for dial - ready chan struct{} // closed when pc, err pair is delivered + cm connectMethod + key connectMethodKey // cm.key() + ctx context.Context // context for dial + timeoutch chan struct{} // tmp timeout to replace ctx + ready chan struct{} // closed when pc, err pair is delivered // hooks for testing to know when dials are done // beforeDial is called in the getConn goroutine when the dial is queued. @@ -1866,6 +2022,11 @@ func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { if w.pc == nil && w.err == nil { panic("net/http: internal error: misuse of tryDeliver") } + select { + case <-w.timeoutch: + pc.close(errors.New("request timeout: dialConn timeout")) + default: + } close(w.ready) return true } diff --git a/x/net/http/util.go b/x/net/http/util.go index bfd9fc3..bec22a8 100644 --- a/x/net/http/util.go +++ b/x/net/http/util.go @@ -7,7 +7,7 @@ import ( "golang.org/x/net/idna" - "github.com/goplus/llgoexamples/x/net" + "github.com/goplus/llgo/x/net" ) /** From 0d8cc277f2688593442b82f3f9f317bd4b9271b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Wed, 4 Sep 2024 18:09:04 +0800 Subject: [PATCH 19/21] WIP(x/net/http/client): Implement IdleConnPool --- x/net/http/_demo/get/get.go | 2 +- x/net/http/_demo/headers/headers.go | 2 +- .../_demo/maxConnsPerHost/maxConnsPerHost.go | 2 +- x/net/http/_demo/post/post.go | 2 +- x/net/http/_demo/postform/postform.go | 2 - x/net/http/_demo/redirect/redirect.go | 2 +- x/net/http/_demo/reuseConn/reuseConn.go | 42 + x/net/http/_demo/timeout/timeout.go | 12 +- x/net/http/client.go | 8 +- x/net/http/request.go | 73 +- x/net/http/transport.go | 974 +++++++++++++----- x/net/http/util.go | 2 +- 12 files changed, 832 insertions(+), 291 deletions(-) create mode 100644 x/net/http/_demo/reuseConn/reuseConn.go diff --git a/x/net/http/_demo/get/get.go b/x/net/http/_demo/get/get.go index 6bc5b06..6e91bd4 100644 --- a/x/net/http/_demo/get/get.go +++ b/x/net/http/_demo/get/get.go @@ -13,6 +13,7 @@ func main() { fmt.Println(err) return } + defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) resp.PrintHeaders() body, err := io.ReadAll(resp.Body) @@ -21,5 +22,4 @@ func main() { return } fmt.Println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/_demo/headers/headers.go b/x/net/http/_demo/headers/headers.go index aa2e5d6..5538923 100644 --- a/x/net/http/_demo/headers/headers.go +++ b/x/net/http/_demo/headers/headers.go @@ -36,6 +36,7 @@ func main() { println(err.Error()) return } + defer resp.Body.Close() fmt.Println(resp.Status) resp.PrintHeaders() body, err := io.ReadAll(resp.Body) @@ -44,5 +45,4 @@ func main() { return } fmt.Println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go index 882bdc1..5662251 100644 --- a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go +++ b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go @@ -19,6 +19,7 @@ func main() { fmt.Println(err) return } + defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) fmt.Println(resp.Proto) resp.PrintHeaders() @@ -28,5 +29,4 @@ func main() { return } fmt.Println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/_demo/post/post.go b/x/net/http/_demo/post/post.go index f169dfc..fd756b3 100644 --- a/x/net/http/_demo/post/post.go +++ b/x/net/http/_demo/post/post.go @@ -15,6 +15,7 @@ func main() { fmt.Println(err) return } + defer resp.Body.Close() fmt.Println(resp.Status) body, err := io.ReadAll(resp.Body) if err != nil { @@ -22,5 +23,4 @@ func main() { return } fmt.Println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/_demo/postform/postform.go b/x/net/http/_demo/postform/postform.go index eae4d6e..232c15d 100644 --- a/x/net/http/_demo/postform/postform.go +++ b/x/net/http/_demo/postform/postform.go @@ -20,12 +20,10 @@ func main() { return } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) return } fmt.Println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/_demo/redirect/redirect.go b/x/net/http/_demo/redirect/redirect.go index e4fdb92..f189255 100644 --- a/x/net/http/_demo/redirect/redirect.go +++ b/x/net/http/_demo/redirect/redirect.go @@ -13,6 +13,7 @@ func main() { fmt.Println(err) return } + defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) fmt.Println(resp.Proto) resp.PrintHeaders() @@ -22,5 +23,4 @@ func main() { return } fmt.Println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/_demo/reuseConn/reuseConn.go b/x/net/http/_demo/reuseConn/reuseConn.go new file mode 100644 index 0000000..bb460ce --- /dev/null +++ b/x/net/http/_demo/reuseConn/reuseConn.go @@ -0,0 +1,42 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgoexamples/x/net/http" +) + +func main() { + // Send request first time + resp, err := http.Get("https://www.baidu.com") + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) + resp.PrintHeaders() + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + resp.Body.Close() + + // Send request second time + resp, err = http.Get("https://www.baidu.com") + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) + resp.PrintHeaders() + body, err = io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + resp.Body.Close() +} diff --git a/x/net/http/_demo/timeout/timeout.go b/x/net/http/_demo/timeout/timeout.go index 62b2c9d..a6930b1 100644 --- a/x/net/http/_demo/timeout/timeout.go +++ b/x/net/http/_demo/timeout/timeout.go @@ -10,24 +10,24 @@ import ( func main() { client := &http.Client{ - //Timeout: time.Millisecond, // Set a small timeout to ensure it will time out - Timeout: time.Second * 5, + Timeout: time.Millisecond, // Set a small timeout to ensure it will time out + //Timeout: time.Second, } req, err := http.NewRequest("GET", "https://www.baidu.com", nil) if err != nil { - fmt.Println(err.Error()) + fmt.Println(err) return } resp, err := client.Do(req) if err != nil { - fmt.Println(err.Error()) + fmt.Println(err) return } + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - fmt.Println(err.Error()) + fmt.Println(err) return } println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/client.go b/x/net/http/client.go index d56f5f2..002397a 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -241,7 +241,6 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { // didTimeout is non-nil only if err != nil. func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { - // TODO(spongehah) cookie(c.send) if c.Jar != nil { for _, cookie := range c.Jar.Cookies(req.URL) { req.AddCookie(cookie) @@ -309,13 +308,16 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d } // TODO(spongehah) timeout(send) + //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) + req.timeoutch = make(chan struct{}, 1) req.deadline = deadline + req.ctx.Done() if deadline.IsZero() { didTimeout = alwaysFalse + defer close(req.timeoutch) } else { didTimeout = func() bool { return req.timer.GetDueIn() == 0 } } - //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) resp, err = rt.RoundTrip(req) if err != nil { @@ -488,7 +490,7 @@ func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { return knownRoundTripperImpl(altRT, req) } return true - // TODO(spongehah) + // TODO(spongehah) http2 //case *http2Transport, http2noDialH2RoundTripper: // return true } diff --git a/x/net/http/request.go b/x/net/http/request.go index 658b033..c5146ed 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -50,6 +50,30 @@ type Request struct { const defaultChunkSize = 8192 +// NOTE: This is not intended to reflect the actual Go version being used. +// It was changed at the time of Go 1.1 release because the former User-Agent +// had ended up blocked by some intrusion detection systems. +// See https://codereview.appspot.com/7532043. +const defaultUserAgent = "Go-http-client/1.1" + +// errMissingHost is returned by Write when there is no Host or URL present in +// the Request. +var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set") + +// Headers that Request.Write handles itself and should be skipped. +var reqWriteExcludeHeader = map[string]bool{ + "Host": true, // not in Header map anyway + "User-Agent": true, + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + +// requestBodyReadError wraps an error from (*Request).write to indicate +// that the error came from a Read call on the Request.Body. +// This error type should not escape the net/http package to users. +type requestBodyReadError struct{ error } + // NewRequest wraps NewRequestWithContext using context.Background. func NewRequest(method, url string, body io.Reader) (*Request, error) { return NewRequestWithContext(context.Background(), method, url, body) @@ -188,6 +212,22 @@ func (r *Request) closeBody() error { return r.Body.Close() } +func (r *Request) isReplayable() bool { + if r.Body == nil || r.Body == NoBody || r.GetBody != nil { + switch valueOrDefault(r.Method, "GET") { + case "GET", "HEAD", "OPTIONS", "TRACE": + return true + } + // The Idempotency-Key, while non-standard, is widely used to + // mean a POST or other request is idempotent. See + // https://golang.org/issue/19943#issuecomment-421092421 + if r.Header.has("Idempotency-Key") || r.Header.has("X-Idempotency-Key") { + return true + } + } + return false +} + // Context returns the request's context. To change the context, use // Clone or WithContext. // @@ -240,25 +280,6 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } -// errMissingHost is returned by Write when there is no Host or URL present in -// the Request. -var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set") - -// NOTE: This is not intended to reflect the actual Go version being used. -// It was changed at the time of Go 1.1 release because the former User-Agent -// had ended up blocked by some intrusion detection systems. -// See https://codereview.appspot.com/7532043. -const defaultUserAgent = "Go-http-client/1.1" - -// Headers that Request.Write handles itself and should be skipped. -var reqWriteExcludeHeader = map[string]bool{ - "Host": true, // not in Header map anyway - "User-Agent": true, - "Content-Length": true, - "Transfer-Encoding": true, - "Trailer": true, -} - // extraHeaders may be nil // waitForContinue may be nil // always closes body @@ -279,7 +300,6 @@ func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hype } // Send it! sendTask := client.Send(hyperReq) - taskData.taskId = read sendTask.SetUserdata(c.Pointer(taskData)) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { @@ -482,11 +502,6 @@ func readCookies(h Header, filter string) []*Cookie { return cookies } -// requestBodyReadError wraps an error from (*Request).write to indicate -// that the error came from a Read call on the Request.Body. -// This error type should not escape the net/http package to users. -type requestBodyReadError struct{ error } - func idnaASCII(v string) (string, error) { // TODO: Consider removing this check after verifying performance is okay. // Right now punycode verification, length checks, context checks, and the @@ -519,3 +534,11 @@ func removeZone(host string) string { } return host[:j] + host[i:] } + +// Return value if nonempty, def otherwise. +func valueOrDefault(value, def string) string { + if value != "" { + return value + } + return def +} diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 062cbe9..44d721d 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -2,6 +2,7 @@ package http import ( "compress/gzip" + "container/list" "context" "errors" "fmt" @@ -17,8 +18,8 @@ import ( "github.com/goplus/llgo/c/libuv" cnet "github.com/goplus/llgo/c/net" "github.com/goplus/llgo/c/syscall" - "github.com/goplus/llgo/x/net" "github.com/goplus/llgoexamples/rust/hyper" + "github.com/goplus/llgoexamples/x/net" ) // DefaultTransport is the default implementation of Transport and is @@ -28,7 +29,9 @@ import ( // and NO_PROXY (or the lowercase versions thereof). var DefaultTransport RoundTripper = &Transport{ //Proxy: ProxyFromEnvironment, - Proxy: nil, + Proxy: nil, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, } // DefaultMaxIdleConnsPerHost is the default value of Transport's @@ -37,6 +40,12 @@ const DefaultMaxIdleConnsPerHost = 2 const debugSwitch = true type Transport struct { + idleMu sync.Mutex + closeIdle bool // user has requested to close all idle conns + idleConn map[connectMethodKey][]*persistConn // most recently used at end + idleConnWait map[connectMethodKey]wantConnQueue // waiting getConns + idleLRU connLRU + altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme reqMu sync.Mutex reqCanceler map[cancelKey]func(error) @@ -63,6 +72,15 @@ type Transport struct { // uncompressed. DisableCompression bool + // MaxIdleConns controls the maximum number of idle (keep-alive) + // connections across all hosts. Zero means no limit. + MaxIdleConns int + + // MaxIdleConnsPerHost, if non-zero, controls the maximum idle + // (keep-alive) connections to keep per-host. If zero, + // DefaultMaxIdleConnsPerHost is used. + MaxIdleConnsPerHost int + // MaxConnsPerHost optionally limits the total number of // connections per host, including connections in the dialing, // active, and idle states. On limit violation, dials will block. @@ -70,9 +88,16 @@ type Transport struct { // Zero means no limit. MaxConnsPerHost int + // IdleConnTimeout is the maximum amount of time an idle + // (keep-alive) connection will remain idle before closing + // itself. + // Zero means no limit. + IdleConnTimeout time.Duration + // libuv and hyper related loopInitOnce sync.Once loop *libuv.Loop + async *libuv.Async exec *hyper.Executor } @@ -181,14 +206,258 @@ func (tr *transportRequest) setError(err error) { tr.mu.Unlock() } -func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { - cm.targetScheme = treq.URL.Scheme - cm.targetAddr = canonicalAddr(treq.URL) - if t.Proxy != nil { - cm.proxyURL, err = t.Proxy(treq.Request) +func (t *Transport) putOrCloseIdleConn(pconn *persistConn) { + if err := t.tryPutIdleConn(pconn); err != nil { + pconn.close(err) + } +} + +func (t *Transport) maxIdleConnsPerHost() int { + if v := t.MaxIdleConnsPerHost; v != 0 { + return v + } + return DefaultMaxIdleConnsPerHost +} + +// tryPutIdleConn adds pconn to the list of idle persistent connections awaiting +// a new request. +// If pconn is no longer needed or not in a good state, tryPutIdleConn returns +// an error explaining why it wasn't registered. +// tryPutIdleConn does not close pconn. Use putOrCloseIdleConn instead for that. +func (t *Transport) tryPutIdleConn(pconn *persistConn) error { + if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 { + return errKeepAlivesDisabled + } + if pconn.isBroken() { + return errConnBroken + } + pconn.markReused() + + t.idleMu.Lock() + defer t.idleMu.Unlock() + + // HTTP/2 (pconn.alt != nil) connections do not come out of the idle list, + // because multiple goroutines can use them simultaneously. + // If this is an HTTP/2 connection being “returned,” we're done. + if pconn.alt != nil && t.idleLRU.m[pconn] != nil { + return nil + } + + // Deliver pconn to goroutine waiting for idle connection, if any. + // (They may be actively dialing, but this conn is ready first. + // Chrome calls this socket late binding. + // See https://www.chromium.org/developers/design-documents/network-stack#TOC-Connection-Management.) + key := pconn.cacheKey + if q, ok := t.idleConnWait[key]; ok { + done := false + if pconn.alt == nil { + // HTTP/1. + // Loop over the waiting list until we find a w that isn't done already, and hand it pconn. + for q.len() > 0 { + w := q.popFront() + if w.tryDeliver(pconn, nil) { + done = true + break + } + } + } else { + // HTTP/2. + // Can hand the same pconn to everyone in the waiting list, + // and we still won't be done: we want to put it in the idle + // list unconditionally, for any future clients too. + for q.len() > 0 { + w := q.popFront() + w.tryDeliver(pconn, nil) + } + } + if q.len() == 0 { + delete(t.idleConnWait, key) + } else { + t.idleConnWait[key] = q + } + if done { + return nil + } + } + + if t.closeIdle { + return errCloseIdle + } + if t.idleConn == nil { + t.idleConn = make(map[connectMethodKey][]*persistConn) + } + idles := t.idleConn[key] + if len(idles) >= t.maxIdleConnsPerHost() { + return errTooManyIdleHost + } + for _, exist := range idles { + if exist == pconn { + log.Fatalf("dup idle pconn %p in freelist", pconn) + } + } + t.idleConn[key] = append(idles, pconn) + t.idleLRU.add(pconn) + if t.MaxIdleConns != 0 && t.idleLRU.len() > t.MaxIdleConns { + oldest := t.idleLRU.removeOldest() + oldest.close(errTooManyIdle) + t.removeIdleConnLocked(oldest) + } + + // Set idle timer, but only for HTTP/1 (pconn.alt == nil). + // The HTTP/2 implementation manages the idle timer itself + // (see idleConnTimeout in h2_bundle.go). + idleConnTimeout := uint64(t.IdleConnTimeout.Milliseconds()) + if t.IdleConnTimeout > 0 && pconn.alt == nil { + if pconn.idleTimer != nil { + pconn.idleTimer.Start(onIdleConnTimeout, idleConnTimeout, 0) + } else { + pconn.idleTimer = &libuv.Timer{} + libuv.InitTimer(t.loop, pconn.idleTimer) + (*libuv.Handle)(c.Pointer(pconn.idleTimer)).SetData(c.Pointer(pconn)) + pconn.idleTimer.Start(onIdleConnTimeout, idleConnTimeout, 0) + } + } + pconn.idleAt = time.Now() + return nil +} + +func onIdleConnTimeout(timer *libuv.Timer) { + pconn := (*persistConn)((*libuv.Handle)(c.Pointer(timer)).GetData()) + isClose := pconn.closeConnIfStillIdle() + if isClose { + timer.Stop() + } else { + timer.Start(onIdleConnTimeout, 0, 0) } - cm.onlyH1 = treq.requiresHTTP1() - return cm, err +} + +// queueForIdleConn queues w to receive the next idle connection for w.cm. +// As an optimization hint to the caller, queueForIdleConn reports whether +// it successfully delivered an already-idle connection. +func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) { + if t.DisableKeepAlives { + return false + } + + t.idleMu.Lock() + defer t.idleMu.Unlock() + + // Stop closing connections that become idle - we might want one. + // (That is, undo the effect of t.CloseIdleConnections.) + t.closeIdle = false + + if w == nil { + // Happens in test hook. + return false + } + + // If IdleConnTimeout is set, calculate the oldest + // persistConn.idleAt time we're willing to use a cached idle + // conn. + var oldTime time.Time + if t.IdleConnTimeout > 0 { + oldTime = time.Now().Add(-t.IdleConnTimeout) + } + // Look for most recently-used idle connection. + if list, ok := t.idleConn[w.key]; ok { + stop := false + delivered := false + for len(list) > 0 && !stop { + pconn := list[len(list)-1] + + // See whether this connection has been idle too long, considering + // only the wall time (the Round(0)), in case this is a laptop or VM + // coming out of suspend with previously cached idle connections. + tooOld := !oldTime.IsZero() && pconn.idleAt.Round(0).Before(oldTime) + if tooOld { + // Async cleanup. Launch in its own goroutine (as if a + // time.AfterFunc called it); it acquires idleMu, which we're + // holding, and does a synchronous net.Conn.Close. + pconn.closeConnIfStillIdleLocked() + } + if pconn.isBroken() || tooOld { + // If either persistConn.readLoop has marked the connection + // broken, but Transport.removeIdleConn has not yet removed it + // from the idle list, or if this persistConn is too old (it was + // idle too long), then ignore it and look for another. In both + // cases it's already in the process of being closed. + list = list[:len(list)-1] + continue + } + delivered = w.tryDeliver(pconn, nil) + if delivered { + if pconn.alt != nil { + // HTTP/2: multiple clients can share pconn. + // Leave it in the list. + } else { + // HTTP/1: only one client can use pconn. + // Remove it from the list. + t.idleLRU.remove(pconn) + list = list[:len(list)-1] + } + } + stop = true + } + if len(list) > 0 { + t.idleConn[w.key] = list + } else { + delete(t.idleConn, w.key) + } + if stop { + return delivered + } + } + + // Register to receive next connection that becomes idle. + if t.idleConnWait == nil { + t.idleConnWait = make(map[connectMethodKey]wantConnQueue) + } + q := t.idleConnWait[w.key] + q.cleanFront() + q.pushBack(w) + t.idleConnWait[w.key] = q + return false +} + +// removeIdleConn marks pconn as dead. +func (t *Transport) removeIdleConn(pconn *persistConn) bool { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return t.removeIdleConnLocked(pconn) +} + +// t.idleMu must be held. +func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool { + if pconn.idleTimer != nil { + pconn.idleTimer.Stop() + (*libuv.Handle)(c.Pointer(pconn.idleTimer)).Close(nil) + } + t.idleLRU.remove(pconn) + key := pconn.cacheKey + pconns := t.idleConn[key] + var removed bool + switch len(pconns) { + case 0: + // Nothing + case 1: + if pconns[0] == pconn { + delete(t.idleConn, key) + removed = true + } + default: + for i, v := range pconns { + if v != pconn { + continue + } + // Slide down, keeping most recently-used + // conns at the end. + copy(pconns[i:], pconns[i+1:]) + t.idleConn[key] = pconns[:len(pconns)-1] + removed = true + break + } + } + return removed } func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { @@ -223,6 +492,16 @@ func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { return true } +func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + cm.targetScheme = treq.URL.Scheme + cm.targetAddr = canonicalAddr(treq.URL) + if t.Proxy != nil { + cm.proxyURL, err = t.Proxy(treq.Request) + } + cm.onlyH1 = treq.requiresHTTP1() + return cm, err +} + // alternateRoundTripper returns the alternate RoundTripper to use // for this request if the Request's URL scheme requires one, // or nil for the normal case of using the Transport. @@ -286,6 +565,9 @@ func (t *Transport) closeLocked(err error) { if t.loop != nil { t.loop.Close() } + if t.async != nil { + t.async.Close(nil) + } if t.exec != nil { t.exec.Free() } @@ -308,13 +590,12 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { defer println("RoundTrip end") } t.loopInitOnce.Do(func() { + println("init loop") t.loop = libuv.LoopNew() + t.async = &libuv.Async{} t.exec = hyper.NewExecutor() - //idle := &libuv.Idle{} - //libuv.InitIdle(t.loop, idle) - //(*libuv.Handle)(c.Pointer(idle)).SetData(c.Pointer(t)) - //idle.Start(readWriteLoop) + t.loop.Async(t.async, nil) checker := &libuv.Check{} libuv.InitCheck(t.loop, checker) @@ -330,7 +611,6 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { // Only the first request will initialize the timer if req.timer == nil && !req.deadline.IsZero() { req.timer = &libuv.Timer{} - req.timeoutch = make(chan struct{}, 1) libuv.InitTimer(t.loop, req.timer) ch := &timeoutData{ timeoutch: req.timeoutch, @@ -473,9 +753,10 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { var resp *Response if pconn.alt != nil { // HTTP/2 path. - t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest + t.setReqCanceler(cancelKey, nil) // HTTP/2 not cancelable with CancelRequest resp, err = pconn.alt.RoundTrip(req) } else { + // HTTP/1.X path. resp, err = pconn.roundTrip(treq) } @@ -485,8 +766,35 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { } // Failed. Clean up and determine whether to retry. - // TODO(spongehah) Retry & ConnPool(t.doRoundTrip) - return nil, err + // TODO(spongehah) ConnPool(t.doRoundTrip) + if http2isNoCachedConnError(err) { + if t.removeIdleConn(pconn) { + t.decConnsPerHost(pconn.cacheKey) + } + } else if !pconn.shouldRetryRequest(req, err) { + // Issue 16465: return underlying net.Conn.Read error from peek, + // as we've historically done. + if e, ok := err.(nothingWrittenError); ok { + err = e.error + } + if e, ok := err.(transportReadFromServerError); ok { + err = e.err + } + if b, ok := req.Body.(*readTrackingBody); ok && !b.didClose { + // Issue 49621: Close the request body if pconn.roundTrip + // didn't do so already. This can happen if the pconn + // write loop exits without reading the write request. + req.closeBody() + } + return nil, err + } + testHookRoundTripRetried() + + // Rewind the body if we're able to. + req, err = rewindBody(req) + if err != nil { + return nil, err + } } } @@ -507,7 +815,6 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi key: cm.key(), //ctx: ctx, timeoutch: treq.timeoutch, - ready: make(chan struct{}, 1), beforeDial: testHookPrePendingDial, afterDial: testHookPostPendingDial, } @@ -518,20 +825,21 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi }() // TODO(spongehah) ConnPool(t.getConn) - //// Queue for idle connection. - //if delivered := t.queueForIdleConn(w); delivered { - // pc := w.pc - // // Trace only for HTTP/1. - // // HTTP/2 calls trace.GotConn itself. - // if pc.alt == nil && trace != nil && trace.GotConn != nil { - // trace.GotConn(pc.gotIdleConnTrace(pc.idleAt)) - // } - // // set request canceler to some non-nil function so we - // // can detect whether it was cleared between now and when - // // we enter roundTrip - // t.setReqCanceler(treq.cancelKey, func(error) {}) - // return pc, nil - //} + // Queue for idle connection. + if delivered := t.queueForIdleConn(w); delivered { + pc := w.pc + // Trace only for HTTP/1. + // HTTP/2 calls trace.GotConn itself. + // TODO(spongehah) trace(t.getConn) + //if pc.alt == nil && trace != nil && trace.GotConn != nil { + // trace.GotConn(pc.gotIdleConnTrace(pc.idleAt)) + //} + // set request canceler to some non-nil function so we + // can detect whether it was cleared between now and when + // we enter roundTrip + t.setReqCanceler(treq.cancelKey, func(error) {}) + return pc, nil + } cancelc := make(chan error, 1) t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err }) @@ -539,52 +847,36 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // Queue for permission to dial. t.queueForDial(w) - // Wait for completion or cancellation. - select { - case <-w.ready: - // Trace success but only for HTTP/1. - // HTTP/2 calls trace.GotConn itself. - //if w.pc != nil && w.pc.alt == nil && trace != nil && trace.GotConn != nil { - // trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()}) - //} - if w.err != nil { - // If the request has been canceled, that's probably - // what caused w.err; if so, prefer to return the - // cancellation error (see golang.org/issue/16049). - select { - // TODO(spongehah) cancel(t.getConn) - //case <-req.Cancel: - // return nil, errRequestCanceledConn - //case <-req.Context().Done(): - // return nil, req.Context().Err() - case <-req.timeoutch: - return nil, errors.New("timeout: req.Context().Err()") - case err := <-cancelc: - if err == errRequestCanceled { - err = errRequestCanceledConn - } - return nil, err - default: - // return below + // Trace success but only for HTTP/1. + // HTTP/2 calls trace.GotConn itself. + //if w.pc != nil && w.pc.alt == nil && trace != nil && trace.GotConn != nil { + // trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()}) + //} + if w.err != nil { + // If the request has been canceled, that's probably + // what caused w.err; if so, prefer to return the + // cancellation error (see golang.org/issue/16049). + select { + // TODO(spongehah) timeout(t.getConn) + //case <-req.Cancel: + // return nil, errRequestCanceledConn + //case <-req.Context().Done(): + // return nil, req.Context().Err() + case <-req.timeoutch: + if debugSwitch { + println("getConn: timeoutch") } + return nil, errors.New("timeout: req.Context().Err()") + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn + } + return nil, err + default: + // return below } - return w.pc, w.err - // TODO(spongehah) cancel(t.getConn) - //case <-req.Cancel: - // return nil, errRequestCanceledConn - //case <-req.Context().Done(): - // return nil, - case <-req.timeoutch: - if debugSwitch { - println("getConn: timeoutch") - } - return nil, errors.New("timeout: req.Context().Err()\n") - case err := <-cancelc: - if err == errRequestCanceled { - err = errRequestCanceledConn - } - return nil, err } + return w.pc, w.err } // queueForDial queues w to wait for permission to begin dialing. @@ -597,7 +889,7 @@ func (t *Transport) queueForDial(w *wantConn) { w.beforeDial() if t.MaxConnsPerHost <= 0 { - go t.dialConnFor(w) + t.dialConnFor(w) return } @@ -609,7 +901,7 @@ func (t *Transport) queueForDial(w *wantConn) { t.connsPerHost = make(map[connectMethodKey]int) } t.connsPerHost[w.key] = n + 1 - go t.dialConnFor(w) + t.dialConnFor(w) return } @@ -633,17 +925,16 @@ func (t *Transport) dialConnFor(w *wantConn) { defer w.afterDial() pc, err := t.dialConn(w.timeoutch, w.cm) - w.tryDeliver(pc, err) // TODO(spongehah) ConnPool(t.dialConnFor) - //delivered := w.tryDeliver(pc, err) - // Handle undelivered or shareable connections - //if err == nil && (!delivered || pc.alt != nil) { - // // pconn was not passed to w, - // // or it is HTTP/2 and can be shared. - // // Add to the idle connection pool. - // t.putOrCloseIdleConn(pc) - //} - + delivered := w.tryDeliver(pc, err) + // If the connection was successfully established but was not passed to w, + // or is a shareable HTTP/2 connection + if err == nil && (!delivered || pc.alt != nil) { + // pconn was not passed to w, + // or it is HTTP/2 and can be shared. + // Add to the idle connection pool. + t.putOrCloseIdleConn(pc) + } // If an error occurs during the dialing process, the connection count for that host is decreased. // This ensures that the connection count remains accurate even in cases where the dial attempt fails. if err != nil { @@ -676,7 +967,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { for q.len() > 0 { w := q.popFront() if w.waiting() { - go t.dialConnFor(w) + t.dialConnFor(w) done = true break } @@ -708,6 +999,7 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * } select { case <-timeoutch: + err = errors.New("[t.dialConn] request timeout") return default: } @@ -716,6 +1008,7 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * cacheKey: cm.key(), closech: make(chan struct{}, 1), writeLoopDone: make(chan struct{}, 1), + alive: true, } //trace := httptrace.ContextClientTrace(ctx) @@ -754,7 +1047,7 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * // } //} else { //conn, err := t.dial(ctx, "tcp", cm.addr()) - conn, err := t.dial(timeoutch, cm.addr()) + conn, err := t.dial(cm.addr()) if err != nil { return nil, err } @@ -811,14 +1104,15 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * select { case <-timeoutch: - conn.Close() - return + err = errors.New("[t.dialConn] request timeout") + pconn.close(err) + return nil, err default: } return pconn, nil } -func (t *Transport) dial(timeoutch chan struct{}, addr string) (*connData, error) { +func (t *Transport) dial(addr string) (*connData, error) { if debugSwitch { println("dial start") defer println("dial end") @@ -862,7 +1156,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { // TODO(spongehah) ConnPool(pc.roundTrip) - //pc.t.putOrCloseIdleConn(pc) + pc.t.putOrCloseIdleConn(pc) return nil, errRequestCanceled } pc.mu.Lock() @@ -925,16 +1219,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err writeErrCh := make(chan error, 1) resc := make(chan responseAndError, 1) - // Hookup the IO - hyperIo := newIoWithConnReadWrite(pc.conn) - // We need an executor generally to poll futures - // Prepare client options - opts := hyper.NewClientConnOptions() - opts.Exec(pc.t.exec) - // send the handshake - handshakeTask := hyper.Handshake(hyperIo, opts) taskData := &taskData{ - taskId: write, req: req, pc: pc, addedGzip: requestedGzip, @@ -942,14 +1227,32 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err callerGone: gone, resc: resc, } - handshakeTask.SetUserdata(c.Pointer(taskData)) - // Send the request to readWriteLoop(). - // Let's wait for the handshake to finish... - pc.t.exec.Push(handshakeTask) - async := &libuv.Async{} - pc.t.loop.Async(async, asyncCb) - async.Send() + if pc.client == nil && !pc.isReused() { + println("first") + // Hookup the IO + hyperIo := newIoWithConnReadWrite(pc.conn) + // We need an executor generally to poll futures + // Prepare client options + opts := hyper.NewClientConnOptions() + opts.Exec(pc.t.exec) + // send the handshake + handshakeTask := hyper.Handshake(hyperIo, opts) + taskData.taskId = handshake + handshakeTask.SetUserdata(c.Pointer(taskData)) + // Send the request to readWriteLoop(). + pc.t.exec.Push(handshakeTask) + } else { + println("second") + taskData.taskId = read + err = req.write(pc.client, taskData, pc.t.exec) + if err != nil { + writeErrCh <- err + } + } + + // Wake up libuv. Loop + pc.t.async.Send() //var respHeaderTimer <-chan time.Time //cancelChan := req.Request.Cancel @@ -1003,7 +1306,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) } return re.res, nil - // TODO(spongehah) cancel(pc.roundTrip) + // TODO(spongehah) timeout(pc.roundTrip) //case <-cancelChan: // canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) // cancelChan = nil @@ -1022,20 +1325,15 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } } -func asyncCb(async *libuv.Async) { - println("async called") -} - // readWriteLoop handles the main I/O loop for a persistent connection. // It processes incoming requests, sends them to the server, and handles responses. -func readWriteLoop(idle *libuv.Check) { - println("polling") - t := (*Transport)((*libuv.Handle)(c.Pointer(idle)).GetData()) +func readWriteLoop(checker *libuv.Check) { + t := (*Transport)((*libuv.Handle)(c.Pointer(checker)).GetData()) // Read this once, before loop starts. (to avoid races in tests) - testHookMu.Lock() - testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead - testHookMu.Unlock() + //testHookMu.Lock() + //testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead + //testHookMu.Unlock() const debugReadWriteLoop = true // Debug switch provided for developers @@ -1043,6 +1341,9 @@ func readWriteLoop(idle *libuv.Check) { // Poll all ready tasks and act on them... for { task := t.exec.Poll() + if debugSwitch { + println("polling") + } if task == nil { return } @@ -1057,28 +1358,51 @@ func readWriteLoop(idle *libuv.Check) { println("taskId: ", taskId) } switch taskId { - case write: + case handshake: if debugReadWriteLoop { println("write") } + err := checkTaskType(task, handshake) + if err != nil { + taskData.writeErrCh <- err + task.Free() + continue + } + + pc := taskData.pc select { - case <-taskData.pc.closech: + case <-pc.closech: task.Free() continue default: } - err := checkTaskType(task, write) - client := (*hyper.ClientConn)(task.Value()) + pc.client = (*hyper.ClientConn)(task.Value()) task.Free() - if err == nil { - // TODO(spongehah) Proxy(writeLoop) - err = taskData.req.Request.write(client, taskData, t.exec) + // TODO(spongehah) Proxy(writeLoop) + taskData.taskId = read + err = taskData.req.Request.write(pc.client, taskData, t.exec) + + if err != nil { + //pc.writeErrCh <- err // to the body reader, which might recycle us + taskData.writeErrCh <- err // to the roundTrip function + pc.close(err) + continue + } + + if debugReadWriteLoop { + println("write end") + } + case read: + if debugReadWriteLoop { + println("read") } - // For this request, no longer need the client - client.Free() + + pc := taskData.pc + + err := checkTaskType(task, read) if bre, ok := err.(requestBodyReadError); ok { err = bre.error // Errors reading from the user's @@ -1093,59 +1417,48 @@ func readWriteLoop(idle *libuv.Check) { if err != nil { //pc.writeErrCh <- err // to the body reader, which might recycle us taskData.writeErrCh <- err // to the roundTrip function - taskData.pc.close(err) + pc.close(err) continue } - if debugReadWriteLoop { - println("write end") - } - case read: - if debugReadWriteLoop { - println("read") - } - - if taskData.pc.closeErr == nil { - taskData.pc.closeErr = errReadLoopExiting + if pc.closeErr == nil { + pc.closeErr = errReadLoopExiting } // TODO(spongehah) ConnPool(readWriteLoop) - //if taskData.pc.tryPutIdleConn == nil { - // //taskData.pc.tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { - // // if err := pc.t.tryPutIdleConn(pc); err != nil { - // // closeErr = err - // // if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { - // // trace.PutIdleConn(err) - // // } - // // return false - // // } - // // if trace != nil && trace.PutIdleConn != nil { - // // trace.PutIdleConn(nil) - // // } - // // return true - // //} - //} + if pc.tryPutIdleConn == nil { + pc.tryPutIdleConn = func() bool { + if err := pc.t.tryPutIdleConn(pc); err != nil { + pc.closeErr = err + // TODO(spongehah) trace(readWriteLoop) + //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // trace.PutIdleConn(err) + //} + return false + } + //if trace != nil && trace.PutIdleConn != nil { + // trace.PutIdleConn(nil) + //} + return true + } + } - err := checkTaskType(task, read) + // Take the results + hyperResp := (*hyper.Response)(task.Value()) + task.Free() - taskData.pc.mu.Lock() - if taskData.pc.numExpectedResponses == 0 { - taskData.pc.closeLocked(errServerClosedIdle) - taskData.pc.mu.Unlock() + pc.mu.Lock() + if pc.numExpectedResponses == 0 { + pc.readLoopPeekFailLocked(hyperResp, err) + pc.mu.Unlock() // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) continue } - taskData.pc.mu.Unlock() + pc.mu.Unlock() //trace := httptrace.ContextClientTrace(rc.req.Context()) - // Take the results - hyperResp := (*hyper.Response)(task.Value()) - task.Free() - var resp *Response var respBody *hyper.Body if err == nil { @@ -1155,7 +1468,7 @@ func readWriteLoop(idle *libuv.Check) { respBody = hyperResp.Body() } else { err = transportReadFromServerError{err} - taskData.pc.closeErr = err + pc.closeErr = err } // No longer need the response @@ -1166,21 +1479,17 @@ func readWriteLoop(idle *libuv.Check) { case taskData.resc <- responseAndError{err: err}: case <-taskData.callerGone: // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) continue } // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) continue } - taskData.pc.mu.Lock() - taskData.pc.numExpectedResponses-- - taskData.pc.mu.Unlock() + pc.mu.Lock() + pc.numExpectedResponses-- + pc.mu.Unlock() bodyWritable := resp.bodyIsWritable() hasBody := taskData.req.Method != "HEAD" && resp.ContentLength != 0 @@ -1189,46 +1498,43 @@ func readWriteLoop(idle *libuv.Check) { // Don't do keep-alive on error if either party requested a close // or we get an unexpected informational (1xx) response. // StatusCode 100 is already handled above. - taskData.pc.alive = false + pc.alive = false } if !hasBody || bodyWritable { - //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) - t.replaceReqCanceler(taskData.req.cancelKey, nil) + replaced := pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) // TODO(spongehah) ConnPool(readWriteLoop) - //// Put the idle conn back into the pool before we send the response - //// so if they process it quickly and make another request, they'll - //// get this same conn. But we use the unbuffered channel 'rc' - //// to guarantee that persistConn.roundTrip got out of its select - //// potentially waiting for this persistConn to close. - //taskData.pc.alive = taskData.pc.alive && + // Put the idle conn back into the pool before we send the response + // so if they process it quickly and make another request, they'll + // get this same conn. But we use the unbuffered channel 'rc' + // to guarantee that persistConn.roundTrip got out of its select + // potentially waiting for this persistConn to close. + pc.alive = pc.alive && + replaced && pc.tryPutIdleConn() + //pc.alive = pc.alive && // !pc.sawEOF && // pc.wroteRequest() && - // replaced && tryPutIdleConn(trace) + // replaced && pc.tryPutIdleConn() if bodyWritable { - taskData.pc.closeErr = errCallerOwnsConn + pc.closeErr = errCallerOwnsConn } select { case taskData.resc <- responseAndError{res: resp}: case <-taskData.callerGone: // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) continue } // Now that they've read from the unbuffered channel, they're safely // out of the select that also waits on this goroutine to die, so // we're allowed to exit now if needed (if alive is false) - testHookReadLoopBeforeNextRead() - if taskData.pc.alive == false { + //testHookReadLoopBeforeNextRead() + if pc.alive == false { // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) } continue } @@ -1242,7 +1548,7 @@ func readWriteLoop(idle *libuv.Check) { fn: func(err error) error { isEOF := err == io.EOF if !isEOF { - if cerr := taskData.pc.canceled(); cerr != nil { + if cerr := pc.canceled(); cerr != nil { return cerr } } @@ -1265,24 +1571,22 @@ func readWriteLoop(idle *libuv.Check) { taskData.taskId = readDone bodyForeachTask.SetUserdata(c.Pointer(taskData)) t.exec.Push(bodyForeachTask) - (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData + if taskData.req.timer != nil { + (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData + } // TODO(spongehah) select blocking(readWriteLoop) //select { //case taskData.resc <- responseAndError{res: resp}: //case <-taskData.callerGone: // // defer - // taskData.pc.close(taskData.pc.closeErr) - // // TODO(spongehah) ConnPool(readWriteLoop) - // //t.removeIdleConn(pc) + // readLoopDefer(pc, t) // continue //} select { case <-taskData.callerGone: // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) continue default: } @@ -1301,45 +1605,48 @@ func readWriteLoop(idle *libuv.Check) { } checkTaskType(task, readDone) - //bodyEOF := task.Type() == hyper.TaskEmpty + bodyEOF := task.Type() == hyper.TaskEmpty // free the task task.Free() - t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool + pc := taskData.pc + + replaced := t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool // TODO(spongehah) ConnPool(readWriteLoop) - //taskData.pc.alive = taskData.pc.alive && + pc.alive = pc.alive && + bodyEOF && + replaced && pc.tryPutIdleConn() + //pc.alive = pc.alive && // bodyEOF && // !pc.sawEOF && // pc.wroteRequest() && // replaced && tryPutIdleConn(trace) - // TODO(spongehah) cancel(pc.readWriteLoop) + // TODO(spongehah) timeout(t.readWriteLoop) //case <-rw.rc.req.Cancel: - // taskData.pc.alive = false + // pc.alive = false // pc.t.CancelRequest(rw.rc.req) //case <-rw.rc.req.Context().Done(): - // taskData.pc.alive = false + // pc.alive = false // pc.t.cancelRequest(rw.rc.cancelKey, rw.rc.req.Context().Err()) - //case <-taskData.pc.closech: - // taskData.pc.alive = false + //case <-pc.closech: + // pc.alive = false //} - select { - case <-taskData.req.timeoutch: - continue - case <-taskData.pc.closech: - taskData.pc.alive = false - default: - } + //select { + //case <-taskData.req.timeoutch: + // continue + //case <-pc.closech: + // pc.alive = false + //default: + //} - if taskData.pc.alive == false { + if pc.alive == false { // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) } - testHookReadLoopBeforeNextRead() + //testHookReadLoopBeforeNextRead() if debugReadWriteLoop { println("readDone end") } @@ -1350,6 +1657,12 @@ func readWriteLoop(idle *libuv.Check) { } } +func readLoopDefer(pc *persistConn, t *Transport) { + pc.close(pc.closeErr) + // TODO(spongehah) ConnPool(readLoopDefer) + t.removeIdleConn(pc) +} + // ---------------------------------------------------------- type taskData struct { @@ -1374,6 +1687,9 @@ type connData struct { } func (conn *connData) Close() error { + if conn == nil { + return nil + } if conn.ReadWaker != nil { conn.ReadWaker.Free() conn.ReadWaker = nil @@ -1535,9 +1851,7 @@ func onTimeout(timer *libuv.Timer) { pc.alive = false pc.t.cancelRequest(taskData.req.cancelKey, errors.New("timeout: req.Context().Err()")) // defer - pc.close(pc.closeErr) - // TODO(spongehah) ConnPool(onTimeout) - //t.removeIdleConn(pc) + readLoopDefer(pc, pc.t) } } @@ -1555,7 +1869,7 @@ type taskId c.Int const ( notSet taskId = iota - write + handshake read readDone ) @@ -1563,13 +1877,13 @@ const ( // checkTaskType checks the task type func checkTaskType(task *hyper.Task, curTaskId taskId) error { switch curTaskId { - case write: + case handshake: if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::write]handshake task error!\n") + log.Printf("[readWriteLoop::handshake]handshake task error!\n") return fail((*hyper.Error)(task.Value())) } if task.Type() != hyper.TaskClientConn { - return fmt.Errorf("[readWriteLoop::write]unexpected task type\n") + return fmt.Errorf("[readWriteLoop::handshake]unexpected task type\n") } return nil case read: @@ -1746,6 +2060,10 @@ type persistConn struct { writeLoopDone chan struct{} // closed when readWriteLoop ends + // Both guarded by Transport.idleMu: + idleAt time.Time // time it last become idle + idleTimer *libuv.Timer // holding an onIdleConnTimeout to close it + mu sync.Mutex // guards following fields numExpectedResponses int closed error // set non-nil when conn is closed, before closech is closed @@ -1754,11 +2072,14 @@ type persistConn struct { // mutateHeaderFunc is an optional func to modify extra // headers on each outbound request before it's written. (the // original Request given to RoundTrip is not modified) + reused bool // whether conn has had successful request/response and is being reused. mutateHeaderFunc func(Header) // other - alive bool // Replace the alive in readLoop - closeErr error // Replace the closeErr in readLoop + alive bool // Replace the alive in readLoop + closeErr error // Replace the closeErr in readLoop + tryPutIdleConn func() bool // Replace the tryPutIdleConn in readLoop + client *hyper.ClientConn } func (pc *persistConn) cancelRequest(err error) { @@ -1779,7 +2100,18 @@ func (pc *persistConn) close(err error) { pc.closeLocked(err) } +// markReused marks this connection as having been successfully used for a +// request and response. +func (pc *persistConn) markReused() { + pc.mu.Lock() + pc.reused = true + pc.mu.Unlock() +} + func (pc *persistConn) closeLocked(err error) { + if debugSwitch { + println("pc closed") + } if err == nil { panic("nil error") } @@ -1795,6 +2127,7 @@ func (pc *persistConn) closeLocked(err error) { } close(pc.closech) close(pc.writeLoopDone) + pc.client.Free() } } pc.mutateHeaderFunc = nil @@ -1866,6 +2199,14 @@ func (pc *persistConn) canceled() error { return pc.canceledErr } +// isReused reports whether this connection has been used before. +func (pc *persistConn) isReused() bool { + pc.mu.Lock() + r := pc.reused + pc.mu.Unlock() + return r +} + // isBroken reports whether this connection is in a known broken state. func (pc *persistConn) isBroken() bool { pc.mu.Lock() @@ -1874,6 +2215,107 @@ func (pc *persistConn) isBroken() bool { return b } +// shouldRetryRequest reports whether we should retry sending a failed +// HTTP request on a new connection. The non-nil input error is the +// error from roundTrip. +func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { + if http2isNoCachedConnError(err) { + // Issue 16582: if the user started a bunch of + // requests at once, they can all pick the same conn + // and violate the server's max concurrent streams. + // Instead, match the HTTP/1 behavior for now and dial + // again to get a new TCP connection, rather than failing + // this request. + return true + } + if err == errMissingHost { + // User error. + return false + } + if !pc.isReused() { + // This was a fresh connection. There's no reason the server + // should've hung up on us. + // + // Also, if we retried now, we could loop forever + // creating new connections and retrying if the server + // is just hanging up on us because it doesn't like + // our request (as opposed to sending an error). + return false + } + if _, ok := err.(nothingWrittenError); ok { + // We never wrote anything, so it's safe to retry, if there's no body or we + // can "rewind" the body with GetBody. + return req.outgoingLength() == 0 || req.GetBody != nil + } + if !req.isReplayable() { + // Don't retry non-idempotent requests. + return false + } + if _, ok := err.(transportReadFromServerError); ok { + // We got some non-EOF net.Conn.Read failure reading + // the 1st response byte from the server. + return true + } + if err == errServerClosedIdle { + // The server replied with io.EOF while we were trying to + // read the response. Probably an unfortunately keep-alive + // timeout, just as the client was writing a request. + return true + } + return false // conservatively +} + +// closeConnIfStillIdle closes the connection if it's still sitting idle. +// This is what's called by the persistConn's idleTimer, and is run in its +// own goroutine. +func (pc *persistConn) closeConnIfStillIdle() bool { + t := pc.t + isLock := t.idleMu.TryLock() + if isLock { + defer t.idleMu.Unlock() + pc.closeConnIfStillIdleLocked() + return true + } + return false +} + +func (pc *persistConn) closeConnIfStillIdleLocked() { + t := pc.t + if _, ok := t.idleLRU.m[pc]; !ok { + // Not idle. + return + } + t.removeIdleConnLocked(pc) + pc.close(errIdleConnTimeout) +} + +func (pc *persistConn) readLoopPeekFailLocked(resp *hyper.Response, err error) { + if pc.closed != nil { + return + } + if is408Message(resp) { + pc.closeLocked(errServerClosedIdle) + return + } + pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %w", err)) +} + +func is408Message(resp *hyper.Response) bool { + httpVersion := int(resp.Version()) + if httpVersion != 10 && httpVersion != 11 { + return false + } + return resp.Status() == 408 +} + +// isNoCachedConnError reports whether err is of type noCachedConnError +// or its equivalent renamed type in net/http2's h2_bundle.go. Both types +// may coexist in the same running program. +func http2isNoCachedConnError(err error) bool { // h2_bundle.go + _, ok := err.(interface{ IsHTTP2NoCachedConnError() }) + return ok +} + // connectMethod is the map key (in its String form) for keeping persistent // TCP connections alive for subsequent HTTP requests. // @@ -1967,7 +2409,8 @@ type wantConn struct { key connectMethodKey // cm.key() ctx context.Context // context for dial timeoutch chan struct{} // tmp timeout to replace ctx - ready chan struct{} // closed when pc, err pair is delivered + ready bool + //ready chan struct{} // closed when pc, err pair is delivered // hooks for testing to know when dials are done // beforeDial is called in the getConn goroutine when the dial is queued. @@ -1985,25 +2428,24 @@ type wantConn struct { func (w *wantConn) cancel(t *Transport, err error) { w.mu.Lock() if w.pc == nil && w.err == nil { - close(w.ready) // catch misbehavior in future delivery + w.ready = true // catch misbehavior in future delivery } - //pc := w.pc + pc := w.pc w.pc = nil w.err = err w.mu.Unlock() // TODO(spongehah) ConnPool(w.cancel) - //if pc != nil { - // t.putOrCloseIdleConn(pc) - //} + if pc != nil { + t.putOrCloseIdleConn(pc) + } } // waiting reports whether w is still waiting for an answer (connection or error). func (w *wantConn) waiting() bool { - select { - case <-w.ready: + if w.ready { return false - default: + } else { return true } } @@ -2022,12 +2464,7 @@ func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { if w.pc == nil && w.err == nil { panic("net/http: internal error: misuse of tryDeliver") } - select { - case <-w.timeoutch: - pc.close(errors.New("request timeout: dialConn timeout")) - default: - } - close(w.ready) + w.ready = true return true } @@ -2200,3 +2637,42 @@ func (gz *gzipReader) Read(p []byte) (n int, err error) { func (gz *gzipReader) Close() error { return gz.body.Close() } + +type connLRU struct { + ll *list.List // list.Element.Value type is of *persistConn + m map[*persistConn]*list.Element +} + +// add adds pc to the head of the linked list. +func (cl *connLRU) add(pc *persistConn) { + if cl.ll == nil { + cl.ll = list.New() + cl.m = make(map[*persistConn]*list.Element) + } + ele := cl.ll.PushFront(pc) + if _, ok := cl.m[pc]; ok { + panic("persistConn was already in LRU") + } + cl.m[pc] = ele +} + +func (cl *connLRU) removeOldest() *persistConn { + ele := cl.ll.Back() + pc := ele.Value.(*persistConn) + cl.ll.Remove(ele) + delete(cl.m, pc) + return pc +} + +// remove removes pc from cl. +func (cl *connLRU) remove(pc *persistConn) { + if ele, ok := cl.m[pc]; ok { + cl.ll.Remove(ele) + delete(cl.m, pc) + } +} + +// len returns the number of items in the cache. +func (cl *connLRU) len() int { + return len(cl.m) +} diff --git a/x/net/http/util.go b/x/net/http/util.go index bec22a8..bfd9fc3 100644 --- a/x/net/http/util.go +++ b/x/net/http/util.go @@ -7,7 +7,7 @@ import ( "golang.org/x/net/idna" - "github.com/goplus/llgo/x/net" + "github.com/goplus/llgoexamples/x/net" ) /** From e55b2616a086b5ef126178efab8a3f4d6884f4d4 Mon Sep 17 00:00:00 2001 From: spongehah <2635879218@qq.com> Date: Wed, 11 Sep 2024 13:22:10 +0800 Subject: [PATCH 20/21] WIP(x/net/http/client): Implement BodyChunk --- x/net/http/_demo/chunked/chunked.go | 29 + x/net/http/_demo/get/get.go | 6 +- x/net/http/_demo/headers/headers.go | 6 +- .../_demo/maxConnsPerHost/maxConnsPerHost.go | 6 +- x/net/http/_demo/post/post.go | 5 + x/net/http/_demo/redirect/redirect.go | 6 +- x/net/http/_demo/reuseConn/reuseConn.go | 12 +- x/net/http/_demo/server/chunkedServer.go | 42 + x/net/http/_demo/upload/upload.go | 8 +- x/net/http/bodyChunk.go | 104 ++ x/net/http/client.go | 3 +- x/net/http/header.go | 70 +- x/net/http/request.go | 11 +- x/net/http/response.go | 228 +++- x/net/http/transfer.go | 57 +- x/net/http/transport.go | 1084 +++++++---------- 16 files changed, 893 insertions(+), 784 deletions(-) create mode 100644 x/net/http/_demo/chunked/chunked.go create mode 100644 x/net/http/_demo/server/chunkedServer.go create mode 100644 x/net/http/bodyChunk.go diff --git a/x/net/http/_demo/chunked/chunked.go b/x/net/http/_demo/chunked/chunked.go new file mode 100644 index 0000000..7b33c0c --- /dev/null +++ b/x/net/http/_demo/chunked/chunked.go @@ -0,0 +1,29 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgoexamples/x/net/http" +) + +func main() { + resp, err := http.Get("http://localhost:8080/chunked") + if err != nil { + fmt.Println(err) + return + } + defer resp.Body.Close() + fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) +} diff --git a/x/net/http/_demo/get/get.go b/x/net/http/_demo/get/get.go index 6e91bd4..392cc72 100644 --- a/x/net/http/_demo/get/get.go +++ b/x/net/http/_demo/get/get.go @@ -15,7 +15,11 @@ func main() { } defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/headers/headers.go b/x/net/http/_demo/headers/headers.go index 5538923..41cc15f 100644 --- a/x/net/http/_demo/headers/headers.go +++ b/x/net/http/_demo/headers/headers.go @@ -38,7 +38,11 @@ func main() { } defer resp.Body.Close() fmt.Println(resp.Status) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { println(err.Error()) diff --git a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go index 5662251..eff95fc 100644 --- a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go +++ b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go @@ -22,7 +22,11 @@ func main() { defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) fmt.Println(resp.Proto) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/post/post.go b/x/net/http/_demo/post/post.go index fd756b3..b700028 100644 --- a/x/net/http/_demo/post/post.go +++ b/x/net/http/_demo/post/post.go @@ -17,6 +17,11 @@ func main() { } defer resp.Body.Close() fmt.Println(resp.Status) + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/redirect/redirect.go b/x/net/http/_demo/redirect/redirect.go index f189255..3d40f3b 100644 --- a/x/net/http/_demo/redirect/redirect.go +++ b/x/net/http/_demo/redirect/redirect.go @@ -16,7 +16,11 @@ func main() { defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) fmt.Println(resp.Proto) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/reuseConn/reuseConn.go b/x/net/http/_demo/reuseConn/reuseConn.go index bb460ce..bccfe9d 100644 --- a/x/net/http/_demo/reuseConn/reuseConn.go +++ b/x/net/http/_demo/reuseConn/reuseConn.go @@ -15,7 +15,11 @@ func main() { return } fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) @@ -31,7 +35,11 @@ func main() { return } fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err = io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/server/chunkedServer.go b/x/net/http/_demo/server/chunkedServer.go new file mode 100644 index 0000000..b79ad60 --- /dev/null +++ b/x/net/http/_demo/server/chunkedServer.go @@ -0,0 +1,42 @@ +package main + +import ( + "fmt" + "net/http" +) + +func chunkedHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("Content-Type", "text/plain") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + sentence := "This is a chunked encoded response. It will be sent in multiple parts. Note the delay between each section." + + words := []string{} + start := 0 + for i, r := range sentence { + if r == '。' || r == ',' || i == len(sentence)-1 { + words = append(words, sentence[start:i+1]) + start = i + 1 + } + } + + for _, word := range words { + fmt.Fprintf(w, "%s", word) + flusher.Flush() + } +} + +func main() { + http.HandleFunc("/chunked", chunkedHandler) + fmt.Println("Starting server on :8080") + err := http.ListenAndServe(":8080", nil) + if err != nil { + fmt.Printf("Error starting server: %s\n", err) + } +} \ No newline at end of file diff --git a/x/net/http/_demo/upload/upload.go b/x/net/http/_demo/upload/upload.go index fe7256b..b5baffa 100644 --- a/x/net/http/_demo/upload/upload.go +++ b/x/net/http/_demo/upload/upload.go @@ -11,7 +11,7 @@ import ( func main() { url := "http://httpbin.org/post" //url := "http://localhost:8080" - filePath := "/Users/spongehah/go/src/llgo/x/net/http/_demo/upload/example.txt" // Replace with your file path + filePath := "/Users/spongehah/Documents/code/GOPATH/src/llgo/x/net/http/_demo/upload/example.txt" // Replace with your file path //filePath := "/Users/spongehah/Downloads/xiaoshuo.txt" // Replace with your file path file, err := os.Open(filePath) @@ -36,7 +36,11 @@ func main() { } defer resp.Body.Close() fmt.Println("Status:", resp.Status) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } respBody, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/bodyChunk.go b/x/net/http/bodyChunk.go new file mode 100644 index 0000000..c1d1072 --- /dev/null +++ b/x/net/http/bodyChunk.go @@ -0,0 +1,104 @@ +package http + +import ( + "errors" + "io" + "sync" + + "github.com/goplus/llgo/c/libuv" +) + +type onceError struct { + sync.Mutex + err error +} + +func (a *onceError) Store(err error) { + a.Lock() + defer a.Unlock() + if a.err != nil { + return + } + a.err = err +} + +func (a *onceError) Load() error { + a.Lock() + defer a.Unlock() + return a.err +} + +func newBodyChunk(asyncHandle *libuv.Async) *bodyChunk { + return &bodyChunk{ + readCh: make(chan []byte, 1), + done: make(chan struct{}), + asyncHandle: asyncHandle, + } +} + +type bodyChunk struct { + chunk []byte + readCh chan []byte + asyncHandle *libuv.Async + + once sync.Once + done chan struct{} + + rerr onceError +} + +var ( + errClosedBodyChunk = errors.New("bodyChunk: read/write on closed body") +) + +func (bc *bodyChunk) Read(p []byte) (n int, err error) { + for n < len(p) { + if len(bc.chunk) == 0 { + select { + case chunk, ok := <-bc.readCh: + if !ok { + if n > 0 { + return n, nil + } + return 0, bc.readCloseError() + } + bc.chunk = chunk + bc.asyncHandle.Send() + case <-bc.done: + if n > 0 { + return n, nil + } + return 0, io.EOF + } + } + + copied := copy(p[n:], bc.chunk) + n += copied + bc.chunk = bc.chunk[copied:] + } + + return n, nil +} + +func (bc *bodyChunk) Close() error { + return bc.closeRead(nil) +} + +func (bc *bodyChunk) readCloseError() error { + if rerr := bc.rerr.Load(); rerr != nil { + return rerr + } + return errClosedBodyChunk +} + +func (bc *bodyChunk) closeRead(err error) error { + if err == nil { + err = io.EOF + } + bc.rerr.Store(err) + bc.once.Do(func() { + close(bc.done) + }) + //close(bc.done) + return nil +} diff --git a/x/net/http/client.go b/x/net/http/client.go index 002397a..7e26395 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -307,7 +307,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d forkReq() } - // TODO(spongehah) timeout(send) + // TODO(spongehah) tmp timeout(send) //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) req.timeoutch = make(chan struct{}, 1) req.deadline = deadline @@ -490,7 +490,6 @@ func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { return knownRoundTripperImpl(altRT, req) } return true - // TODO(spongehah) http2 //case *http2Transport, http2noDialH2RoundTripper: // return true } diff --git a/x/net/http/header.go b/x/net/http/header.go index 0d1e2cc..7c95411 100644 --- a/x/net/http/header.go +++ b/x/net/http/header.go @@ -75,15 +75,6 @@ func (h Header) Del(key string) { textproto.MIMEHeader(h).Del(key) } -// CanonicalHeaderKey returns the canonical format of the -// header key s. The canonicalization converts the first -// letter and any letter following a hyphen to upper case; -// the rest are converted to lowercase. For example, the -// canonical key for "accept-encoding" is "Accept-Encoding". -// If s contains a space or invalid header field bytes, it is -// returned without modifications. -func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } - // Clone returns a copy of h or nil if h is nil. func (h Header) Clone() Header { if h == nil { @@ -111,28 +102,6 @@ func (h Header) Clone() Header { return h2 } -var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") - -type keyValues struct { - key string - values []string -} - -// A headerSorter implements sort.Interface by sorting a []keyValues -// by key. It's used as a pointer, so it can fit in a sort.Interface -// interface value without allocation. -type headerSorter struct { - kvs []keyValues -} - -func (s *headerSorter) Len() int { return len(s.kvs) } -func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } -func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } - -var headerSorterPool = sync.Pool{ - New: func() any { return new(headerSorter) }, -} - // sortedKeyValues returns h's keys sorted in the returned kvs // slice. The headerSorter used to sort is also returned, for possible // return to headerSorterCache. @@ -199,6 +168,37 @@ func (h Header) writeSubset(reqHeaders *hyper.Headers, exclude map[string]bool) return nil } +var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") + +type keyValues struct { + key string + values []string +} + +// A headerSorter implements sort.Interface by sorting a []keyValues +// by key. It's used as a pointer, so it can fit in a sort.Interface +// interface value without allocation. +type headerSorter struct { + kvs []keyValues +} + +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } + +var headerSorterPool = sync.Pool{ + New: func() any { return new(headerSorter) }, +} + +// CanonicalHeaderKey returns the canonical format of the +// header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +// If s contains a space or invalid header field bytes, it is +// returned without modifications. +func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } + // hasToken reports whether token appears with v, ASCII // case-insensitive, with space or comma boundaries. // token must be all lowercase. @@ -251,11 +251,3 @@ func appendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, va resp.Header.Add(nameStr, valueStr) return hyper.IterContinue } - -func (resp *Response) PrintHeaders() { - for key, values := range resp.Header { - for _, value := range values { - fmt.Printf("%s: %s\n", key, value) - } - } -} diff --git a/x/net/http/request.go b/x/net/http/request.go index c5146ed..e9279fc 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -294,7 +294,7 @@ func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hype //} // Prepare the hyper.Request - hyperReq, err := r.newHyperRequest(taskData.pc.isProxy, taskData.req.extra) + hyperReq, err := r.newHyperRequest(taskData.pc.isProxy, taskData.req.extra, taskData.req) if err != nil { return err } @@ -308,7 +308,7 @@ func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hype return err } -func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.Request, error) { +func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header, treq *transportRequest) (*hyper.Request, error) { // Find the target host. Prefer the Host: header, but if that // is not given, use the host from the request URL. // @@ -401,11 +401,6 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.R } // Process Body,ContentLength,Close,Trailer - //tw, err := newTransferWriter(r) - //if err != nil { - // return err - //} - //err = tw.writeHeader(w, trace) err = r.writeHeader(reqHeaders) if err != nil { return nil, err @@ -433,7 +428,7 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.R } // Write body and trailer - err = r.writeBody(hyperReq) + err = r.writeBody(hyperReq, treq) if err != nil { return nil, err } diff --git a/x/net/http/response.go b/x/net/http/response.go index 6ff5b3d..a3a96fc 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -1,10 +1,12 @@ package http import ( + "compress/gzip" + "errors" "fmt" "io" "strconv" - "unsafe" + "sync" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" @@ -32,7 +34,198 @@ func (r *Response) closeBody() { } } -func ReadResponse(r *io.PipeReader, req *Request, hyperResp *hyper.Response) (*Response, error) { +// bodyIsWritable reports whether the Body supports writing. The +// Transport returns Writable bodies for 101 Switching Protocols +// responses. +// The Transport uses this method to determine whether a persistent +// connection is done being managed from its perspective. Once we +// return a writable response body to a user, the net/http package is +// done managing that connection. +func (r *Response) bodyIsWritable() bool { + _, ok := r.Body.(io.Writer) + return ok +} + +// Cookies parses and returns the cookies set in the Set-Cookie headers. +func (r *Response) Cookies() []*Cookie { + return readSetCookies(r.Header) +} + +func (r *Response) checkRespBody(taskData *taskData) (needContinue bool) { + pc := taskData.pc + bodyWritable := r.bodyIsWritable() + hasBody := taskData.req.Method != "HEAD" && r.ContentLength != 0 + + if r.Close || taskData.req.Close || r.StatusCode <= 199 || bodyWritable { + // Don't do keep-alive on error if either party requested a close + // or we get an unexpected informational (1xx) response. + // StatusCode 100 is already handled above. + pc.alive = false + } + + if !hasBody || bodyWritable { + replaced := pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) + + // Put the idle conn back into the pool before we send the response + // so if they process it quickly and make another request, they'll + // get this same conn. But we use the unbuffered channel 'rc' + // to guarantee that persistConn.roundTrip got out of its select + // potentially waiting for this persistConn to close. + pc.alive = pc.alive && + replaced && pc.tryPutIdleConn() + + if bodyWritable { + pc.closeErr = errCallerOwnsConn + } + + select { + case taskData.resc <- responseAndError{res: r}: + case <-taskData.callerGone: + readLoopDefer(pc, true) + return true + } + // Now that they've read from the unbuffered channel, they're safely + // out of the select that also waits on this goroutine to die, so + // we're allowed to exit now if needed (if alive is false) + readLoopDefer(pc, false) + return true + } + return false +} + +func (r *Response) wrapRespBody(taskData *taskData) { + body := &bodyEOFSignal{ + body: r.Body, + earlyCloseFn: func() error { + return nil + }, + fn: func(err error) error { + isEOF := err == io.EOF + if !isEOF { + if cerr := taskData.pc.canceled(); cerr != nil { + return cerr + } + } + return err + }, + } + r.Body = body + // TODO(spongehah) gzip(wrapRespBody) + //if taskData.addedGzip && EqualFold(r.Header.Get("Content-Encoding"), "gzip") { + // println("gzip reader") + // r.Body = &gzipReader{body: body} + // r.Header.Del("Content-Encoding") + // r.Header.Del("Content-Length") + // r.ContentLength = -1 + // r.Uncompressed = true + //} +} + +// bodyEOFSignal is used by the HTTP/1 transport when reading response +// bodies to make sure we see the end of a response body before +// proceeding and reading on the connection again. +// +// It wraps a ReadCloser but runs fn (if non-nil) at most +// once, right before its final (error-producing) Read or Close call +// returns. fn should return the new error to return from Read or Close. +// +// If earlyCloseFn is non-nil and Close is called before io.EOF is +// seen, earlyCloseFn is called instead of fn, and its return value is +// the return value from Close. +type bodyEOFSignal struct { + body io.ReadCloser + mu sync.Mutex // guards following 4 fields + closed bool // whether Close has been called + rerr error // sticky Read error + fn func(error) error // err will be nil on Read io.EOF + earlyCloseFn func() error // optional alt Close func used if io.EOF not seen +} + +var errReadOnClosedResBody = errors.New("http: read on closed response body") + +func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { + es.mu.Lock() + closed, rerr := es.closed, es.rerr + es.mu.Unlock() + if closed { + return 0, errReadOnClosedResBody + } + if rerr != nil { + return 0, rerr + } + + n, err = es.body.Read(p) + if err != nil { + es.mu.Lock() + defer es.mu.Unlock() + if es.rerr == nil { + es.rerr = err + } + err = es.condfn(err) + } + return +} + +func (es *bodyEOFSignal) Close() error { + es.mu.Lock() + defer es.mu.Unlock() + if es.closed { + return nil + } + es.closed = true + if es.earlyCloseFn != nil && es.rerr != io.EOF { + return es.earlyCloseFn() + } + err := es.body.Close() + return es.condfn(err) +} + +// caller must hold es.mu. +func (es *bodyEOFSignal) condfn(err error) error { + if es.fn == nil { + return err + } + err = es.fn(err) + es.fn = nil + return err +} + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type gzipReader struct { + _ incomparable + body *bodyEOFSignal // underlying HTTP/1 response body framing + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // any error from gzip.NewReader; sticky +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zr == nil { + if gz.zerr == nil { + gz.zr, gz.zerr = gzip.NewReader(gz.body) + } + if gz.zerr != nil { + return 0, gz.zerr + } + } + + gz.body.mu.Lock() + if gz.body.closed { + err = errReadOnClosedResBody + } + gz.body.mu.Unlock() + + if err != nil { + return 0, err + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} + +func ReadResponse(r io.ReadCloser, req *Request, hyperResp *hyper.Response) (*Response, error) { resp := &Response{ Request: req, Header: make(Header), @@ -65,20 +258,6 @@ func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { headers.Foreach(appendToResponseHeader, c.Pointer(resp)) } -// appendToResponseBody BodyForeachCallback function: Process the response body -func appendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { - writer := (*io.PipeWriter)(userdata) - bufLen := chunk.Len() - bytes := unsafe.Slice(chunk.Bytes(), bufLen) - _, err := writer.Write(bytes) - if err != nil { - fmt.Println("Error writing to response body:", err) - writer.Close() - return hyper.IterBreak - } - return hyper.IterContinue -} - // RFC 7234, section 5.4: Should treat // // Pragma: no-cache @@ -94,26 +273,9 @@ func fixPragmaCacheControl(header Header) { } } -// Cookies parses and returns the cookies set in the Set-Cookie headers. -func (r *Response) Cookies() []*Cookie { - return readSetCookies(r.Header) -} - // isProtocolSwitchHeader reports whether the request or response header // is for a protocol switch. func isProtocolSwitchHeader(h Header) bool { return h.Get("Upgrade") != "" && HeaderValuesContainsToken(h["Connection"], "Upgrade") } - -// bodyIsWritable reports whether the Body supports writing. The -// Transport returns Writable bodies for 101 Switching Protocols -// responses. -// The Transport uses this method to determine whether a persistent -// connection is done being managed from its perspective. Once we -// return a writable response body to a user, the net/http package is -// done managing that connection. -func (r *Response) bodyIsWritable() bool { - _, ok := r.Body.(io.Writer) - return ok -} diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index 103200c..818fb3c 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -98,7 +98,7 @@ func (uste *unsupportedTEError) Error() string { } // msg is *Request or *Response. -func readTransfer(msg any, r *io.PipeReader) (err error) { +func readTransfer(msg any, r io.ReadCloser) (err error) { t := &transferReader{RequestMethod: "GET"} // Unify input @@ -173,19 +173,17 @@ func readTransfer(msg any, r *io.PipeReader) (err error) { if isResponse && noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { t.Body = NoBody } else { - // TODO(spongehah) ChunkReader(readTransfer) - //t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} - t.Body = &body{src: r, hdr: msg, r: r, closing: t.Close} + t.Body = &body{src: r, closer: r, hdr: msg, r: r, closing: t.Close} } case realLength == 0: t.Body = NoBody case realLength > 0: - t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} + t.Body = &body{src: io.LimitReader(r, realLength), closer: r, closing: t.Close} default: // realLength < 0, i.e. "Content-Length" not mentioned in header if t.Close { // Close semantics (i.e. HTTP/1.0) - t.Body = &body{src: r, closing: t.Close} + t.Body = &body{src: r, closer: r, closing: t.Close} } else { // Persistent connection (i.e. HTTP/1.1) t.Body = NoBody @@ -349,9 +347,9 @@ func fixTrailer(header Header, chunked bool) (Header, error) { // Close ensures that the body has been fully read // and then reads the trailer if necessary. type body struct { - src io.Reader - hdr any // non-nil (Response or Request) value means read trailer - //r *bufio.Reader // underlying wire-format reader for the trailer + src io.Reader + closer io.Closer + hdr any // non-nil (Response or Request) value means read trailer r io.Reader // underlying wire-format reader for the trailer closing bool // is the connection to be closed after reading body? doEarlyClose bool // whether Close should stop early @@ -476,6 +474,15 @@ func (b *body) Close() error { _, err = io.Copy(io.Discard, bodyLocked{b}) } b.closed = true + + // Close bodyChunk + if b.closer != nil { + closeErr := b.closer.Close() + if err == nil { + err = closeErr + } + } + return err } @@ -654,26 +661,26 @@ func unwrapNopCloser(r io.Reader) (underlyingReader io.Reader, isNopCloser bool) // files (*os.File types) are properly optimized. // // This function is only intended for use in writeBody. -func (req *Request) unwrapBody() io.Reader { - if r, ok := unwrapNopCloser(req.Body); ok { +func (r *Request) unwrapBody() io.Reader { + if r, ok := unwrapNopCloser(r.Body); ok { return r } - if r, ok := req.Body.(*readTrackingBody); ok { + if r, ok := r.Body.(*readTrackingBody); ok { r.didRead = true return r.ReadCloser } - return req.Body + return r.Body } -func (r *Request) writeBody(hyperReq *hyper.Request) error { +func (r *Request) writeBody(hyperReq *hyper.Request, treq *transportRequest) error { if r.Body != nil { var body = r.unwrapBody() hyperReqBody := hyper.NewBody() buf := make([]byte, defaultChunkSize) reqData := &bodyReq{ - body: body, - buf: buf, - closeBody: r.closeBody, + body: body, + buf: buf, + treq: treq, } hyperReqBody.SetUserdata(c.Pointer(reqData)) hyperReqBody.SetDataFunc(setPostData) @@ -683,9 +690,9 @@ func (r *Request) writeBody(hyperReq *hyper.Request) error { } type bodyReq struct { - body io.Reader - buf []byte - closeBody func() error + body io.Reader + buf []byte + treq *transportRequest } func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { @@ -694,10 +701,11 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In if err != nil { if err == io.EOF { *chunk = nil - req.closeBody() + req.treq.closeBody() return hyper.PollReady } fmt.Println("error reading request body: ", err) + req.treq.setError(requestBodyReadError{err}) return hyper.PollError } if n > 0 { @@ -706,10 +714,11 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In } if n == 0 { *chunk = nil - req.closeBody() + req.treq.closeBody() return hyper.PollReady } - req.closeBody() - fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + req.treq.closeBody() + err = fmt.Errorf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + req.treq.setError(requestBodyReadError{err}) return hyper.PollError } diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 44d721d..8075133 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -1,7 +1,6 @@ package http import ( - "compress/gzip" "container/list" "context" "errors" @@ -37,7 +36,12 @@ var DefaultTransport RoundTripper = &Transport{ // DefaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 -const debugSwitch = true + +// Debug switch provided for developers +const ( + debugSwitch = true + debugReadWriteLoop = true +) type Transport struct { idleMu sync.Mutex @@ -46,53 +50,24 @@ type Transport struct { idleConnWait map[connectMethodKey]wantConnQueue // waiting getConns idleLRU connLRU - altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme reqMu sync.Mutex reqCanceler map[cancelKey]func(error) - Proxy func(*Request) (*url.URL, error) + + altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme connsPerHostMu sync.Mutex connsPerHost map[connectMethodKey]int connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns - // DisableKeepAlives, if true, disables HTTP keep-alives and - // will only use the connection to the server for a single - // HTTP request. - // - // This is unrelated to the similarly named TCP keep-alives. - DisableKeepAlives bool - - // DisableCompression, if true, prevents the Transport from - // requesting compression with an "Accept-Encoding: gzip" - // request header when the Request contains no existing - // Accept-Encoding value. If the Transport requests gzip on - // its own and gets a gzipped response, it's transparently - // decoded in the Response.Body. However, if the user - // explicitly requested gzip it is not automatically - // uncompressed. - DisableCompression bool + Proxy func(*Request) (*url.URL, error) - // MaxIdleConns controls the maximum number of idle (keep-alive) - // connections across all hosts. Zero means no limit. - MaxIdleConns int + DisableKeepAlives bool + DisableCompression bool - // MaxIdleConnsPerHost, if non-zero, controls the maximum idle - // (keep-alive) connections to keep per-host. If zero, - // DefaultMaxIdleConnsPerHost is used. + MaxIdleConns int MaxIdleConnsPerHost int - - // MaxConnsPerHost optionally limits the total number of - // connections per host, including connections in the dialing, - // active, and idle states. On limit violation, dials will block. - // - // Zero means no limit. - MaxConnsPerHost int - - // IdleConnTimeout is the maximum amount of time an idle - // (keep-alive) connection will remain idle before closing - // itself. - // Zero means no limit. - IdleConnTimeout time.Duration + MaxConnsPerHost int + IdleConnTimeout time.Duration // libuv and hyper related loopInitOnce sync.Once @@ -516,14 +491,11 @@ func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { // useRegisteredProtocol reports whether an alternate protocol (as registered // with Transport.RegisterProtocol) should be respected for this request. func (t *Transport) useRegisteredProtocol(req *Request) bool { - if req.URL.Scheme == "https" && req.requiresHTTP1() { - // If this request requires HTTP/1, don't use the - // "https" alternate protocol, which is used by the - // HTTP/2 code to take over requests if there's an - // existing cached HTTP/2 connection. - return false - } - return true + // If this request requires HTTP/1, don't use the + // "https" alternate protocol, which is used by the + // HTTP/2 code to take over requests if there's an + // existing cached HTTP/2 connection. + return !(req.URL.Scheme == "https" && req.requiresHTTP1()) } // CancelRequest cancels an in-flight request by closing its connection. @@ -573,6 +545,8 @@ func (t *Transport) closeLocked(err error) { } } +// ---------------------------------------------------------- + func getMilliseconds(deadline time.Time) uint64 { microseconds := deadline.Sub(time.Now()).Microseconds() milliseconds := microseconds / 1e3 @@ -582,15 +556,13 @@ func getMilliseconds(deadline time.Time) uint64 { return uint64(milliseconds) } -// ---------------------------------------------------------- - func (t *Transport) RoundTrip(req *Request) (*Response, error) { if debugSwitch { - println("RoundTrip start") - defer println("RoundTrip end") + println("############### RoundTrip start") + defer println("############### RoundTrip end") } t.loopInitOnce.Do(func() { - println("init loop") + println("############### init loop") t.loop = libuv.LoopNew() t.async = &libuv.Async{} t.exec = hyper.NewExecutor() @@ -620,7 +592,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { req.timer.Start(onTimeout, getMilliseconds(req.deadline), 0) if debugSwitch { - println("timer start") + println("############### timer start") } didTimeout = func() bool { return req.timer.GetDueIn() == 0 } stopTimer = func() { @@ -628,7 +600,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { req.timer.Stop() (*libuv.Handle)(c.Pointer(req.timer)).Close(nil) if debugSwitch { - println("timer close") + println("############### timer close") } } } else { @@ -654,8 +626,8 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { func (t *Transport) doRoundTrip(req *Request) (*Response, error) { if debugSwitch { - println("doRoundTrip start") - defer println("doRoundTrip end") + println("############### doRoundTrip start") + defer println("############### doRoundTrip end") } //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) //ctx := req.Context() @@ -715,7 +687,6 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { } for { - // TODO(spongehah) timeout(t.doRoundTrip) //select { //case <-ctx.Done(): // req.closeBody() @@ -766,7 +737,6 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { } // Failed. Clean up and determine whether to retry. - // TODO(spongehah) ConnPool(t.doRoundTrip) if http2isNoCachedConnError(err) { if t.removeIdleConn(pconn) { t.decConnsPerHost(pconn.cacheKey) @@ -800,8 +770,8 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error) { if debugSwitch { - println("getConn start") - defer println("getConn end") + println("############### getConn start") + defer println("############### getConn end") } req := treq.Request //trace := treq.trace @@ -824,13 +794,11 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi } }() - // TODO(spongehah) ConnPool(t.getConn) // Queue for idle connection. if delivered := t.queueForIdleConn(w); delivered { pc := w.pc // Trace only for HTTP/1. // HTTP/2 calls trace.GotConn itself. - // TODO(spongehah) trace(t.getConn) //if pc.alt == nil && trace != nil && trace.GotConn != nil { // trace.GotConn(pc.gotIdleConnTrace(pc.idleAt)) //} @@ -853,28 +821,28 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()}) //} if w.err != nil { - // If the request has been canceled, that's probably - // what caused w.err; if so, prefer to return the - // cancellation error (see golang.org/issue/16049). - select { - // TODO(spongehah) timeout(t.getConn) - //case <-req.Cancel: - // return nil, errRequestCanceledConn - //case <-req.Context().Done(): - // return nil, req.Context().Err() - case <-req.timeoutch: - if debugSwitch { - println("getConn: timeoutch") - } - return nil, errors.New("timeout: req.Context().Err()") - case err := <-cancelc: - if err == errRequestCanceled { - err = errRequestCanceledConn - } - return nil, err - default: - // return below + return nil, w.err + } + // If the request has been canceled, that's probably + // what caused w.err; if so, prefer to return the + // cancellation error (see golang.org/issue/16049). + select { + //case <-req.Cancel: + // return nil, errRequestCanceledConn + //case <-req.Context().Done(): + // return nil, req.Context().Err() + case <-req.timeoutch: + if debugSwitch { + println("############### getConn: timeoutch") + } + return nil, errors.New("timeout: req.Context().Err()") + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn } + return nil, err + default: + // return below } return w.pc, w.err } @@ -883,8 +851,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // Once w receives permission to dial, it will do so in a separate goroutine. func (t *Transport) queueForDial(w *wantConn) { if debugSwitch { - println("queueForDial start") - defer println("queueForDial end") + println("############### queueForDial start") + defer println("############### queueForDial end") } w.beforeDial() @@ -919,13 +887,12 @@ func (t *Transport) queueForDial(w *wantConn) { // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. func (t *Transport) dialConnFor(w *wantConn) { if debugSwitch { - println("dialConnFor start") - defer println("dialConnFor end") + println("############### dialConnFor start") + defer println("############### dialConnFor end") } defer w.afterDial() pc, err := t.dialConn(w.timeoutch, w.cm) - // TODO(spongehah) ConnPool(t.dialConnFor) delivered := w.tryDeliver(pc, err) // If the connection was successfully established but was not passed to w, // or is a shareable HTTP/2 connection @@ -994,8 +961,8 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn *persistConn, err error) { if debugSwitch { - println("dialConn start") - defer println("dialConn end") + println("############### dialConn start") + defer println("############### dialConn end") } select { case <-timeoutch: @@ -1009,7 +976,9 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * closech: make(chan struct{}, 1), writeLoopDone: make(chan struct{}, 1), alive: true, + chunkAsync: &libuv.Async{}, } + t.loop.Async(pconn.chunkAsync, readyToRead) //trace := httptrace.ContextClientTrace(ctx) //wrapErr := func(err error) error { @@ -1102,6 +1071,21 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * // } //} + pconn.closeErr = errReadLoopExiting + pconn.tryPutIdleConn = func() bool { + if err := pconn.t.tryPutIdleConn(pconn); err != nil { + pconn.closeErr = err + //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // trace.PutIdleConn(err) + //} + return false + } + //if trace != nil && trace.PutIdleConn != nil { + // trace.PutIdleConn(nil) + //} + return true + } + select { case <-timeoutch: err = errors.New("[t.dialConn] request timeout") @@ -1114,8 +1098,8 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * func (t *Transport) dial(addr string) (*connData, error) { if debugSwitch { - println("dial start") - defer println("dial end") + println("############### dial start") + defer println("############### dial end") } host, port, err := net.SplitHostPort(addr) if err != nil { @@ -1150,12 +1134,11 @@ func (t *Transport) dial(addr string) (*connData, error) { func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { if debugSwitch { - println("roundTrip start") - defer println("roundTrip end") + println("############### roundTrip start") + defer println("############### roundTrip end") } testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { - // TODO(spongehah) ConnPool(pc.roundTrip) pc.t.putOrCloseIdleConn(pc) return nil, errRequestCanceled } @@ -1168,40 +1151,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err headerFn(req.extraHeaders()) } - // Ask for a compressed version if the caller didn't set their - // own value for Accept-Encoding. We only attempt to - // uncompress the gzip stream if we were the layer that - // requested it. - requestedGzip := false - // TODO(spongehah) gzip(pc.roundTrip) - //if !pc.t.DisableCompression && - // req.Header.Get("Accept-Encoding") == "" && - // req.Header.Get("Range") == "" && - // req.Method != "HEAD" { - // // Request gzip only, not deflate. Deflate is ambiguous and - // // not as universally supported anyway. - // // See: https://zlib.net/zlib_faq.html#faq39 - // // - // // Note that we don't request this for HEAD requests, - // // due to a bug in nginx: - // // https://trac.nginx.org/nginx/ticket/358 - // // https://golang.org/issue/5522 - // // - // // We don't request gzip if the request is for a range, since - // // auto-decoding a portion of a gzipped document will just fail - // // anyway. See https://golang.org/issue/8923 - // requestedGzip = true - // req.extraHeaders().Set("Accept-Encoding", "gzip") - //} - - // The 100-continue operation in Hyper is handled in the newHyperRequest function. - - // Keep-Alive - if pc.t.DisableKeepAlives && - !req.wantsClose() && - !isProtocolSwitchHeader(req.Header) { - req.extraHeaders().Set("Connection", "close") - } + // Set extra headers, such as Accept-Encoding, Connection(Keep-Alive). + requestedGzip := pc.setExtraHeaders(req) gone := make(chan struct{}, 1) defer close(gone) @@ -1229,9 +1180,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } if pc.client == nil && !pc.isReused() { - println("first") // Hookup the IO - hyperIo := newIoWithConnReadWrite(pc.conn) + hyperIo := newHyperIo(pc.conn) // We need an executor generally to poll futures // Prepare client options opts := hyper.NewClientConnOptions() @@ -1243,7 +1193,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // Send the request to readWriteLoop(). pc.t.exec.Push(handshakeTask) } else { - println("second") taskData.taskId = read err = req.write(pc.client, taskData, pc.t.exec) if err != nil { @@ -1264,12 +1213,12 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err for { testHookWaitResLoop() if debugSwitch { - println("roundTrip for") + println("############### roundTrip for") } select { case err := <-writeErrCh: if debugSwitch { - println("roundTrip: writeErrch") + println("############### roundTrip: writeErrch") } if err != nil { pc.close(fmt.Errorf("write error: %w", err)) @@ -1278,17 +1227,9 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } return nil, pc.mapRoundTripError(req, startBytesWritten, err) } - //if d := pc.t.ResponseHeaderTimeout; d > 0 { - // if debugRoundTrip { - // //req.logf("starting timer for %v", d) - // } - // timer := time.NewTimer(d) - // defer timer.Stop() // prevent leaks - // respHeaderTimer = timer.C - //} case <-pcClosed: if debugSwitch { - println("roundTrip: pcClosed") + println("############### roundTrip: pcClosed") } pcClosed = nil if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { @@ -1297,7 +1238,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err //case <-respHeaderTimer: case re := <-resc: if debugSwitch { - println("roundTrip: resc") + println("############### roundTrip: resc") } if (re.res == nil) == (re.err == nil) { return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) @@ -1306,7 +1247,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) } return re.res, nil - // TODO(spongehah) timeout(pc.roundTrip) //case <-cancelChan: // canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) // cancelChan = nil @@ -1316,7 +1256,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // ctxDoneChan = nil case <-timeoutch: if debugSwitch { - println("roundTrip: timeoutch") + println("############### roundTrip: timeoutch") } canceled = pc.t.cancelRequest(req.cancelKey, errors.New("timeout: req.Context().Err()")) timeoutch = nil @@ -1330,361 +1270,232 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err func readWriteLoop(checker *libuv.Check) { t := (*Transport)((*libuv.Handle)(c.Pointer(checker)).GetData()) - // Read this once, before loop starts. (to avoid races in tests) - //testHookMu.Lock() - //testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead - //testHookMu.Unlock() - - const debugReadWriteLoop = true // Debug switch provided for developers - - // The polling state machine! - // Poll all ready tasks and act on them... - for { - task := t.exec.Poll() + // The polling state machine! Poll all ready tasks and act on them... + task := t.exec.Poll() + for task != nil { if debugSwitch { - println("polling") - } - if task == nil { - return - } - taskData := (*taskData)(task.Userdata()) - var taskId taskId - if taskData != nil { - taskId = taskData.taskId - } else { - taskId = notSet + println("############### polling") } + t.handleTask(task) + task = t.exec.Poll() + } +} + +func (t *Transport) handleTask(task *hyper.Task) { + taskData := (*taskData)(task.Userdata()) + if taskData == nil { + // A background task for hyper_client completed... + task.Free() + return + } + var err error + pc := taskData.pc + // If original taskId is set, we need to check it + err = checkTaskType(task, taskData) + if err != nil { + readLoopDefer(pc, true) + return + } + switch taskData.taskId { + case handshake: if debugReadWriteLoop { - println("taskId: ", taskId) + println("############### write") } - switch taskId { - case handshake: - if debugReadWriteLoop { - println("write") - } - - err := checkTaskType(task, handshake) - if err != nil { - taskData.writeErrCh <- err - task.Free() - continue - } - - pc := taskData.pc - select { - case <-pc.closech: - task.Free() - continue - default: - } - pc.client = (*hyper.ClientConn)(task.Value()) + // Check if the connection is closed + select { + case <-pc.closech: task.Free() + return + default: + } - // TODO(spongehah) Proxy(writeLoop) - taskData.taskId = read - err = taskData.req.Request.write(pc.client, taskData, t.exec) + pc.client = (*hyper.ClientConn)(task.Value()) + task.Free() - if err != nil { - //pc.writeErrCh <- err // to the body reader, which might recycle us - taskData.writeErrCh <- err // to the roundTrip function - pc.close(err) - continue - } - - if debugReadWriteLoop { - println("write end") - } - case read: - if debugReadWriteLoop { - println("read") - } - - pc := taskData.pc - - err := checkTaskType(task, read) - if bre, ok := err.(requestBodyReadError); ok { - err = bre.error - // Errors reading from the user's - // Request.Body are high priority. - // Set it here before sending on the - // channels below or calling - // pc.close() which tears down - // connections and causes other - // errors. - taskData.req.setError(err) - } - if err != nil { - //pc.writeErrCh <- err // to the body reader, which might recycle us - taskData.writeErrCh <- err // to the roundTrip function - pc.close(err) - continue - } + // TODO(spongehah) Proxy(writeLoop) + taskData.taskId = read + err = taskData.req.Request.write(pc.client, taskData, t.exec) - if pc.closeErr == nil { - pc.closeErr = errReadLoopExiting - } - // TODO(spongehah) ConnPool(readWriteLoop) - if pc.tryPutIdleConn == nil { - pc.tryPutIdleConn = func() bool { - if err := pc.t.tryPutIdleConn(pc); err != nil { - pc.closeErr = err - // TODO(spongehah) trace(readWriteLoop) - //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { - // trace.PutIdleConn(err) - //} - return false - } - //if trace != nil && trace.PutIdleConn != nil { - // trace.PutIdleConn(nil) - //} - return true - } - } + if err != nil { + //pc.writeErrCh <- err // to the body reader, which might recycle us + taskData.writeErrCh <- err // to the roundTrip function + pc.close(err) + return + } - // Take the results - hyperResp := (*hyper.Response)(task.Value()) - task.Free() + if debugReadWriteLoop { + println("############### write end") + } + case read: + if debugReadWriteLoop { + println("############### read") + } - pc.mu.Lock() - if pc.numExpectedResponses == 0 { - pc.readLoopPeekFailLocked(hyperResp, err) - pc.mu.Unlock() + // Take the results + hyperResp := (*hyper.Response)(task.Value()) + task.Free() - // defer - readLoopDefer(pc, t) - continue - } + //pc.mu.Lock() + if pc.numExpectedResponses == 0 { + pc.readLoopPeekFailLocked(hyperResp, err) pc.mu.Unlock() + readLoopDefer(pc, true) + return + } + //pc.mu.Unlock() - //trace := httptrace.ContextClientTrace(rc.req.Context()) - - var resp *Response - var respBody *hyper.Body - if err == nil { - var pr *io.PipeReader - pr, taskData.bodyWriter = io.Pipe() - resp, err = ReadResponse(pr, taskData.req.Request, hyperResp) - respBody = hyperResp.Body() - } else { - err = transportReadFromServerError{err} - pc.closeErr = err - } + var resp *Response + if err == nil { + pc.chunkAsync.SetData(c.Pointer(taskData)) + bc := newBodyChunk(pc.chunkAsync) + pc.bodyChunk = bc + resp, err = ReadResponse(bc, taskData.req.Request, hyperResp) + taskData.hyperBody = hyperResp.Body() + } else { + err = transportReadFromServerError{err} + pc.closeErr = err + } - // No longer need the response - hyperResp.Free() + // No longer need the response + hyperResp.Free() - if err != nil { - select { - case taskData.resc <- responseAndError{err: err}: - case <-taskData.callerGone: - // defer - readLoopDefer(pc, t) - continue - } - // defer - readLoopDefer(pc, t) - continue + if err != nil { + select { + case taskData.resc <- responseAndError{err: err}: + case <-taskData.callerGone: + readLoopDefer(pc, true) + return } + readLoopDefer(pc, true) + return + } - pc.mu.Lock() - pc.numExpectedResponses-- - pc.mu.Unlock() - - bodyWritable := resp.bodyIsWritable() - hasBody := taskData.req.Method != "HEAD" && resp.ContentLength != 0 + dataTask := taskData.hyperBody.Data() + taskData.taskId = readBodyChunk + dataTask.SetUserdata(c.Pointer(taskData)) + t.exec.Push(dataTask) - if resp.Close || taskData.req.Close || resp.StatusCode <= 199 || bodyWritable { - // Don't do keep-alive on error if either party requested a close - // or we get an unexpected informational (1xx) response. - // StatusCode 100 is already handled above. - pc.alive = false - } + if !taskData.req.deadline.IsZero() { + (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData + } - if !hasBody || bodyWritable { - replaced := pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) - - // TODO(spongehah) ConnPool(readWriteLoop) - // Put the idle conn back into the pool before we send the response - // so if they process it quickly and make another request, they'll - // get this same conn. But we use the unbuffered channel 'rc' - // to guarantee that persistConn.roundTrip got out of its select - // potentially waiting for this persistConn to close. - pc.alive = pc.alive && - replaced && pc.tryPutIdleConn() - //pc.alive = pc.alive && - // !pc.sawEOF && - // pc.wroteRequest() && - // replaced && pc.tryPutIdleConn() - - if bodyWritable { - pc.closeErr = errCallerOwnsConn - } + //pc.mu.Lock() + pc.numExpectedResponses-- + //pc.mu.Unlock() - select { - case taskData.resc <- responseAndError{res: resp}: - case <-taskData.callerGone: - // defer - readLoopDefer(pc, t) - continue - } - // Now that they've read from the unbuffered channel, they're safely - // out of the select that also waits on this goroutine to die, so - // we're allowed to exit now if needed (if alive is false) - //testHookReadLoopBeforeNextRead() - if pc.alive == false { - // defer - readLoopDefer(pc, t) - } - continue - } + needContinue := resp.checkRespBody(taskData) + if needContinue { + return + } - body := &bodyEOFSignal{ - body: resp.Body, - earlyCloseFn: func() error { - taskData.bodyWriter.Close() - return nil - }, - fn: func(err error) error { - isEOF := err == io.EOF - if !isEOF { - if cerr := pc.canceled(); cerr != nil { - return cerr - } - } - return err - }, - } - resp.Body = body - - // TODO(spongehah) gzip(pc.readWriteLoop) - //if taskData.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { - // println("gzip reader") - // resp.Body = &gzipReader{body: body} - // resp.Header.Del("Content-Encoding") - // resp.Header.Del("Content-Length") - // resp.ContentLength = -1 - // resp.Uncompressed = true - //} + resp.wrapRespBody(taskData) - bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(taskData.bodyWriter)) - taskData.taskId = readDone - bodyForeachTask.SetUserdata(c.Pointer(taskData)) - t.exec.Push(bodyForeachTask) - if taskData.req.timer != nil { - (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData - } + // FIXME: Waiting for the channel bug to be fixed + //select { + //case taskData.resc <- responseAndError{res: resp}: + //case <-taskData.callerGone: + // // defer + // readLoopDefer(pc, true) + // return + //} + select { + case <-taskData.callerGone: + readLoopDefer(pc, true) + return + default: + } + taskData.resc <- responseAndError{res: resp} - // TODO(spongehah) select blocking(readWriteLoop) - //select { - //case taskData.resc <- responseAndError{res: resp}: - //case <-taskData.callerGone: - // // defer - // readLoopDefer(pc, t) - // continue - //} - select { - case <-taskData.callerGone: - // defer - readLoopDefer(pc, t) - continue - default: - } - taskData.resc <- responseAndError{res: resp} + if debugReadWriteLoop { + println("############### read end") + } + case readBodyChunk: + if debugReadWriteLoop { + println("############### readBodyChunk") + } + taskType := task.Type() + if taskType == hyper.TaskBuf { + chunk := (*hyper.Buf)(task.Value()) + chunkLen := chunk.Len() + bytes := unsafe.Slice(chunk.Bytes(), chunkLen) + // Free chunk and task + chunk.Free() + task.Free() + // Write to the channel + pc.bodyChunk.readCh <- bytes if debugReadWriteLoop { - println("read end") - } - case readDone: - // A background task of reading the response body is completed - if debugReadWriteLoop { - println("readDone") - } - if taskData.bodyWriter != nil { - taskData.bodyWriter.Close() + println("############### readBodyChunk end [buf]") } - checkTaskType(task, readDone) - - bodyEOF := task.Type() == hyper.TaskEmpty - // free the task - task.Free() - - pc := taskData.pc - - replaced := t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool - // TODO(spongehah) ConnPool(readWriteLoop) - pc.alive = pc.alive && - bodyEOF && - replaced && pc.tryPutIdleConn() - //pc.alive = pc.alive && - // bodyEOF && - // !pc.sawEOF && - // pc.wroteRequest() && - // replaced && tryPutIdleConn(trace) - - // TODO(spongehah) timeout(t.readWriteLoop) - //case <-rw.rc.req.Cancel: - // pc.alive = false - // pc.t.CancelRequest(rw.rc.req) - //case <-rw.rc.req.Context().Done(): - // pc.alive = false - // pc.t.cancelRequest(rw.rc.cancelKey, rw.rc.req.Context().Err()) - //case <-pc.closech: - // pc.alive = false - //} + return + } - //select { - //case <-taskData.req.timeoutch: - // continue - //case <-pc.closech: - // pc.alive = false - //default: - //} + // taskType == taskEmpty (check in checkTaskType) + task.Free() + taskData.hyperBody.Free() + taskData.hyperBody = nil + pc.bodyChunk.Close() + replaced := t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool + pc.alive = pc.alive && + replaced && pc.tryPutIdleConn() - if pc.alive == false { - // defer - readLoopDefer(pc, t) - } + readLoopDefer(pc, false) - //testHookReadLoopBeforeNextRead() - if debugReadWriteLoop { - println("readDone end") - } - case notSet: - // A background task for hyper_client completed... - task.Free() + if debugReadWriteLoop { + println("############### readBodyChunk end [empty]") } } } -func readLoopDefer(pc *persistConn, t *Transport) { +func readyToRead(aysnc *libuv.Async) { + println("############### AsyncCb: readyToRead") + taskData := (*taskData)(aysnc.GetData()) + dataTask := taskData.hyperBody.Data() + dataTask.SetUserdata(c.Pointer(taskData)) + taskData.pc.t.exec.Push(dataTask) +} + +// readLoopDefer Replace the defer function of readLoop in stdlib +func readLoopDefer(pc *persistConn, force bool) { + if pc.alive == true && !force { + return + } pc.close(pc.closeErr) - // TODO(spongehah) ConnPool(readLoopDefer) - t.removeIdleConn(pc) + pc.t.removeIdleConn(pc) } // ---------------------------------------------------------- +type connData struct { + TcpHandle libuv.Tcp + ConnectReq libuv.Connect + ReadBuf libuv.Buf + ReadBufFilled uintptr + nwrite int64 // bytes written(Replaced from persistConn's nwrite) + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker +} + type taskData struct { taskId taskId - bodyWriter *io.PipeWriter req *transportRequest pc *persistConn addedGzip bool writeErrCh chan error callerGone chan struct{} resc chan responseAndError + hyperBody *hyper.Body } -type connData struct { - TcpHandle libuv.Tcp - ConnectReq libuv.Connect - ReadBuf libuv.Buf - ReadBufFilled uintptr - nwrite int64 // bytes written(Replaced from persistConn's nwrite) - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker -} +// taskId The unique identifier of the next task polled from the executor +type taskId c.Int + +const ( + handshake taskId = iota + 1 + read + readBodyChunk +) func (conn *connData) Close() error { if conn == nil { @@ -1709,8 +1520,8 @@ func (conn *connData) Close() error { // onConnect is the libuv callback for a successful connection func onConnect(req *libuv.Connect, status c.Int) { if debugSwitch { - println("connect start") - defer println("connect end") + println("############### connect start") + defer println("############### connect end") } conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) @@ -1723,8 +1534,6 @@ func onConnect(req *libuv.Connect, status c.Int) { // allocBuffer allocates a buffer for reading from a socket func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { - //conn := (*ConnData)(handle.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(handle)).data conn := (*connData)(handle.GetData()) if conn.ReadBuf.Base == nil { conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) @@ -1738,31 +1547,21 @@ func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { // onRead is the libuv callback for reading from a socket // This callback function is called when data is available to be read func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { - // Get the connection data associated with the stream conn := (*connData)((*libuv.Handle)(c.Pointer(stream)).GetData()) - - // If data was read (nread > 0) if nread > 0 { - // Update the amount of filled buffer conn.ReadBufFilled += uintptr(nread) } - // If there's a pending read waker if conn.ReadWaker != nil { // Wake up the pending read operation of Hyper conn.ReadWaker.Wake() - // Clear the waker reference conn.ReadWaker = nil } } // readCallBack read callback function for Hyper library func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - // Get the user data (connection data) conn := (*connData)(userdata) - - // If there's data in the buffer if conn.ReadBufFilled > 0 { - // Calculate how much data to copy (minimum of filled amount and requested amount) var toCopy uintptr if bufLen < conn.ReadBufFilled { toCopy = bufLen @@ -1775,71 +1574,52 @@ func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uin c.Memmove(c.Pointer(conn.ReadBuf.Base), c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+toCopy), conn.ReadBufFilled-toCopy) // Update the amount of filled buffer conn.ReadBufFilled -= toCopy - // Return the number of bytes copied return toCopy } - // If no data in buffer, set up a waker to wait for more data - // Free the old waker if it exists if conn.ReadWaker != nil { conn.ReadWaker.Free() } - // Create a new waker conn.ReadWaker = ctx.Waker() - // Return HYPER_IO_PENDING to indicate operation is pending, waiting for more data return hyper.IoPending } // onWrite is the libuv callback for writing to a socket // Callback function called after a write operation completes func onWrite(req *libuv.Write, status c.Int) { - // Get the connection data associated with the write request conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) - - // If there's a pending write waker if conn.WriteWaker != nil { // Wake up the pending write operation conn.WriteWaker.Wake() - // Clear the waker reference conn.WriteWaker = nil } } // writeCallBack write callback function for Hyper library func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - // Get the user data (connection data) conn := (*connData)(userdata) - // Create a libuv buffer initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) req := &libuv.Write{} - // Associate the connection data with the write request (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) - // Perform the asynchronous write operation ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, onWrite) - // If the write operation was successfully initiated if ret >= 0 { conn.nwrite += int64(bufLen) - // Return the number of bytes to be written return bufLen } - // If the write operation can't complete immediately, set up a waker to wait for completion if conn.WriteWaker != nil { - // Free the old waker if it exists conn.WriteWaker.Free() } - // Create a new waker conn.WriteWaker = ctx.Waker() - // Return HYPER_IO_PENDING to indicate operation is pending, waiting for write to complete return hyper.IoPending } // onTimeout is the libuv callback for a timeout func onTimeout(timer *libuv.Timer) { if debugSwitch { - println("onTimeout start") - defer println("onTimeout end") + println("############### onTimeout start") + defer println("############### onTimeout end") } data := (*timeoutData)((*libuv.Handle)(c.Pointer(timer)).GetData()) close(data.timeoutch) @@ -1850,13 +1630,12 @@ func onTimeout(timer *libuv.Timer) { pc := taskData.pc pc.alive = false pc.t.cancelRequest(taskData.req.cancelKey, errors.New("timeout: req.Context().Err()")) - // defer - readLoopDefer(pc, pc.t) + readLoopDefer(pc, true) } } -// newIoWithConnReadWrite creates a new IO with read and write callbacks -func newIoWithConnReadWrite(connData *connData) *hyper.Io { +// newHyperIo creates a new IO with read and write callbacks +func newHyperIo(connData *connData) *hyper.Io { hyperIo := hyper.NewIo() hyperIo.SetUserdata(c.Pointer(connData)) hyperIo.SetRead(readCallBack) @@ -1864,79 +1643,98 @@ func newIoWithConnReadWrite(connData *connData) *hyper.Io { return hyperIo } -// taskId The unique identifier of the next task polled from the executor -type taskId c.Int - -const ( - notSet taskId = iota - handshake - read - readDone -) - // checkTaskType checks the task type -func checkTaskType(task *hyper.Task, curTaskId taskId) error { - switch curTaskId { - case handshake: - if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::handshake]handshake task error!\n") - return fail((*hyper.Error)(task.Value())) - } - if task.Type() != hyper.TaskClientConn { - return fmt.Errorf("[readWriteLoop::handshake]unexpected task type\n") - } - return nil - case read: - if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::read]write task error!\n") - return fail((*hyper.Error)(task.Value())) - } - if task.Type() != hyper.TaskResponse { - c.Printf(c.Str("[readWriteLoop::read]unexpected task type\n")) - return errors.New("[readWriteLoop::read]unexpected task type\n") +func checkTaskType(task *hyper.Task, taskData *taskData) (err error) { + curTaskId := taskData.taskId + taskType := task.Type() + if taskType == hyper.TaskError { + err = fail((*hyper.Error)(task.Value()), curTaskId) + } + if err == nil { + switch curTaskId { + case handshake: + if taskType != hyper.TaskClientConn { + err = errors.New("Unexpected hyper task type: expected to be TaskClientConn, actual is " + strTaskType(taskType)) + } + case read: + if taskType != hyper.TaskResponse { + err = errors.New("Unexpected hyper task type: expected to be TaskResponse, actual is " + strTaskType(taskType)) + } + case readBodyChunk: + if taskType != hyper.TaskBuf && taskType != hyper.TaskEmpty { + err = errors.New("Unexpected hyper task type: expected to be TaskBuf / TaskEmpty, actual is " + strTaskType(taskType)) + } } - return nil - case readDone: - if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::readDone]read response body error!\n") - return fail((*hyper.Error)(task.Value())) + } + if err != nil { + task.Free() + if curTaskId == handshake || curTaskId == read { + taskData.writeErrCh <- err + taskData.pc.close(err) } - return nil - case notSet: + taskData.pc.alive = false } - return errors.New("[readWriteLoop]unexpected task type\n") + return } // fail prints the error details and panics -func fail(err *hyper.Error) error { +func fail(err *hyper.Error, taskId taskId) error { if err != nil { - c.Printf(c.Str("[readWriteLoop]error code: %d\n"), err.Code()) // grab the error details var errBuf [256]c.Char errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) - - c.Printf(c.Str("[readWriteLoop]details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) + errDetails := unsafe.SliceData(errBuf[:errLen]) + details := c.GoString(errDetails) // clean up the error err.Free() - return fmt.Errorf("[readWriteLoop]hyper request error, error code: %d\n", int(err.Code())) + return fmt.Errorf("hyper request error, taskId: %s, details: %s\n", strTaskId(taskId), details) } return nil } +func strTaskType(taskType hyper.TaskReturnType) string { + switch taskType { + case hyper.TaskClientConn: + return "TaskClientConn" + case hyper.TaskResponse: + return "TaskResponse" + case hyper.TaskBuf: + return "TaskBuf" + case hyper.TaskEmpty: + return "TaskEmpty" + case hyper.TaskError: + return "TaskError" + default: + return "Unknown" + } +} + +func strTaskId(taskId taskId) string { + switch taskId { + case handshake: + return "handshake" + case read: + return "read" + case readBodyChunk: + return "readBodyChunk" + default: + return "notSet" + } +} + // ---------------------------------------------------------- // error values for debugging and testing, not seen by users. var ( - errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") - errConnBroken = errors.New("http: putIdleConn: connection is in bad state") - errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") - errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") - errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") - errCloseIdleConns = errors.New("http: CloseIdleConnections called") - errReadLoopExiting = errors.New("http: Transport.readWriteLoop.read exiting") - errReadWriteLoopExiting = errors.New("http: Transport.readWriteLoop exiting") - errIdleConnTimeout = errors.New("http: idle connection timeout") + errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") + errConnBroken = errors.New("http: putIdleConn: connection is in bad state") + errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") + errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") + errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") + errCloseIdleConns = errors.New("http: CloseIdleConnections called") + errReadLoopExiting = errors.New("http: Transport.readWriteLoop.read exiting") + errIdleConnTimeout = errors.New("http: idle connection timeout") // errServerClosedIdle is not seen by users for idempotent requests, but may be // seen by a user if the server shuts down an idle connection and sends its FIN @@ -1971,14 +1769,6 @@ func (e *httpError) Error() string { return e.err } func (e *httpError) Timeout() bool { return e.timeout } func (e *httpError) Temporary() bool { return true } -// fakeLocker is a sync.Locker which does nothing. It's used to guard -// test-only fields when not under test, to avoid runtime atomic -// overhead. -type fakeLocker struct{} - -func (fakeLocker) Lock() {} -func (fakeLocker) Unlock() {} - // nothingWrittenError wraps a write errors which ended up writing zero bytes. type nothingWrittenError struct { error @@ -2014,9 +1804,6 @@ var ( testHookRoundTripRetried = nop testHookPrePendingDial = nop testHookPostPendingDial = nop - - testHookMu sync.Locker = fakeLocker{} // guards following - testHookReadLoopBeforeNextRead = nop ) var portMap = map[string]string{ @@ -2076,10 +1863,34 @@ type persistConn struct { mutateHeaderFunc func(Header) // other - alive bool // Replace the alive in readLoop - closeErr error // Replace the closeErr in readLoop - tryPutIdleConn func() bool // Replace the tryPutIdleConn in readLoop - client *hyper.ClientConn + alive bool // Replace the alive in readLoop + closeErr error // Replace the closeErr in readLoop + tryPutIdleConn func() bool // Replace the tryPutIdleConn in readLoop + client *hyper.ClientConn // http long connection client handle + bodyChunk *bodyChunk // Implement non-blocking consumption of each responseBody chunk + chunkAsync *libuv.Async // Notifying that the received chunk has been read +} + +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle in +// a "keep-alive" state. It does not interrupt any connections currently +// in use. +func (t *Transport) CloseIdleConnections() { + //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) + t.idleMu.Lock() + m := t.idleConn + t.idleConn = nil + t.closeIdle = true // close newly idle connections + t.idleLRU = connLRU{} + t.idleMu.Unlock() + for _, conns := range m { + for _, pconn := range conns { + pconn.close(errCloseIdleConns) + } + } + //if t2 := t.h2transport; t2 != nil { + // t2.CloseIdleConnections() + //} } func (pc *persistConn) cancelRequest(err error) { @@ -2110,7 +1921,7 @@ func (pc *persistConn) markReused() { func (pc *persistConn) closeLocked(err error) { if debugSwitch { - println("pc closed") + println("############### pc closed") } if err == nil { panic("nil error") @@ -2128,6 +1939,7 @@ func (pc *persistConn) closeLocked(err error) { close(pc.closech) close(pc.writeLoopDone) pc.client.Free() + pc.chunkAsync.Close(nil) } } pc.mutateHeaderFunc = nil @@ -2256,13 +2068,11 @@ func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { // the 1st response byte from the server. return true } - if err == errServerClosedIdle { - // The server replied with io.EOF while we were trying to - // read the response. Probably an unfortunately keep-alive - // timeout, just as the client was writing a request. - return true - } - return false // conservatively + // The server replied with io.EOF while we were trying to + // read the response. Probably an unfortunately keep-alive + // timeout, just as the client was writing a request. + // conservatively return false. + return err == errServerClosedIdle } // closeConnIfStillIdle closes the connection if it's still sitting idle. @@ -2300,6 +2110,45 @@ func (pc *persistConn) readLoopPeekFailLocked(resp *hyper.Response, err error) { pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %w", err)) } +// setExtraHeaders Set extra headers, such as Accept-Encoding, Connection(Keep-Alive). +func (pc *persistConn) setExtraHeaders(req *transportRequest) bool { + // Ask for a compressed version if the caller didn't set their + // own value for Accept-Encoding. We only attempt to + // uncompress the gzip stream if we were the layer that + // requested it. + requestedGzip := false + // TODO(spongehah) gzip(pc.roundTrip) + //if !pc.t.DisableCompression && + // req.Header.Get("Accept-Encoding") == "" && + // req.Header.Get("Range") == "" && + // req.Method != "HEAD" { + // // Request gzip only, not deflate. Deflate is ambiguous and + // // not as universally supported anyway. + // // See: https://zlib.net/zlib_faq.html#faq39 + // // + // // Note that we don't request this for HEAD requests, + // // due to a bug in nginx: + // // https://trac.nginx.org/nginx/ticket/358 + // // https://golang.org/issue/5522 + // // + // // We don't request gzip if the request is for a range, since + // // auto-decoding a portion of a gzipped document will just fail + // // anyway. See https://golang.org/issue/8923 + // requestedGzip = true + // req.extraHeaders().Set("Accept-Encoding", "gzip") + //} + + // The 100-continue operation in Hyper is handled in the newHyperRequest function. + + // Keep-Alive + if pc.t.DisableKeepAlives && + !req.wantsClose() && + !isProtocolSwitchHeader(req.Header) { + req.extraHeaders().Set("Connection", "close") + } + return requestedGzip +} + func is408Message(resp *hyper.Response) bool { httpVersion := int(resp.Version()) if httpVersion != 10 && httpVersion != 11 { @@ -2435,7 +2284,6 @@ func (w *wantConn) cancel(t *Transport, err error) { w.err = err w.mu.Unlock() - // TODO(spongehah) ConnPool(w.cancel) if pc != nil { t.putOrCloseIdleConn(pc) } @@ -2534,110 +2382,6 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) { } } -// bodyEOFSignal is used by the HTTP/1 transport when reading response -// bodies to make sure we see the end of a response body before -// proceeding and reading on the connection again. -// -// It wraps a ReadCloser but runs fn (if non-nil) at most -// once, right before its final (error-producing) Read or Close call -// returns. fn should return the new error to return from Read or Close. -// -// If earlyCloseFn is non-nil and Close is called before io.EOF is -// seen, earlyCloseFn is called instead of fn, and its return value is -// the return value from Close. -type bodyEOFSignal struct { - body io.ReadCloser - mu sync.Mutex // guards following 4 fields - closed bool // whether Close has been called - rerr error // sticky Read error - fn func(error) error // err will be nil on Read io.EOF - earlyCloseFn func() error // optional alt Close func used if io.EOF not seen -} - -var errReadOnClosedResBody = errors.New("http: read on closed response body") - -func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { - es.mu.Lock() - closed, rerr := es.closed, es.rerr - es.mu.Unlock() - if closed { - return 0, errReadOnClosedResBody - } - if rerr != nil { - return 0, rerr - } - - n, err = es.body.Read(p) - if err != nil { - es.mu.Lock() - defer es.mu.Unlock() - if es.rerr == nil { - es.rerr = err - } - err = es.condfn(err) - } - return -} - -func (es *bodyEOFSignal) Close() error { - es.mu.Lock() - defer es.mu.Unlock() - if es.closed { - return nil - } - es.closed = true - if es.earlyCloseFn != nil && es.rerr != io.EOF { - return es.earlyCloseFn() - } - err := es.body.Close() - return es.condfn(err) -} - -// caller must hold es.mu. -func (es *bodyEOFSignal) condfn(err error) error { - if es.fn == nil { - return err - } - err = es.fn(err) - es.fn = nil - return err -} - -// gzipReader wraps a response body so it can lazily -// call gzip.NewReader on the first call to Read -type gzipReader struct { - _ incomparable - body *bodyEOFSignal // underlying HTTP/1 response body framing - zr *gzip.Reader // lazily-initialized gzip reader - zerr error // any error from gzip.NewReader; sticky -} - -func (gz *gzipReader) Read(p []byte) (n int, err error) { - if gz.zr == nil { - if gz.zerr == nil { - gz.zr, gz.zerr = gzip.NewReader(gz.body) - } - if gz.zerr != nil { - return 0, gz.zerr - } - } - - gz.body.mu.Lock() - if gz.body.closed { - err = errReadOnClosedResBody - } - gz.body.mu.Unlock() - - if err != nil { - return 0, err - } - return gz.zr.Read(p) -} - -func (gz *gzipReader) Close() error { - return gz.body.Close() -} - type connLRU struct { ll *list.List // list.Element.Value type is of *persistConn m map[*persistConn]*list.Element From a4e86311ecd8999669de14624189c4b42de2c151 Mon Sep 17 00:00:00 2001 From: spongehah <2635879218@qq.com> Date: Fri, 20 Sep 2024 12:48:10 +0800 Subject: [PATCH 21/21] WIP(x/net/http/client): Mutiple eventLoop --- go.mod | 2 +- go.sum | 4 +- .../_demo/parallelRequest/parallelRequest.go | 43 ++ x/net/http/bodyChunk.go | 84 +-- x/net/http/client.go | 193 +++--- x/net/http/request.go | 73 +-- x/net/http/response.go | 23 +- x/net/http/server.go | 7 - x/net/http/transfer.go | 60 +- x/net/http/transport.go | 597 ++++++++++-------- 10 files changed, 525 insertions(+), 561 deletions(-) create mode 100644 x/net/http/_demo/parallelRequest/parallelRequest.go diff --git a/go.mod b/go.mod index e893515..95f17e6 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/goplus/llgoexamples go 1.20 require ( - github.com/goplus/llgo v0.9.7-0.20240830010153-2434fd778f0b + github.com/goplus/llgo v0.9.8-0.20240919105235-c6436ea6d196 golang.org/x/net v0.28.0 ) diff --git a/go.sum b/go.sum index 5d7faad..08150d6 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/goplus/llgo v0.9.7-0.20240830010153-2434fd778f0b h1:iC0vVA8F2DNJ9wVyHI9fP9U0nM+si3LSQJ1TtGftXyo= -github.com/goplus/llgo v0.9.7-0.20240830010153-2434fd778f0b/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= +github.com/goplus/llgo v0.9.8-0.20240919105235-c6436ea6d196 h1:LckJktvgChf3x0eex+GT//JkYRj1uiT4uMLzyrg3ChU= +github.com/goplus/llgo v0.9.8-0.20240919105235-c6436ea6d196/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= diff --git a/x/net/http/_demo/parallelRequest/parallelRequest.go b/x/net/http/_demo/parallelRequest/parallelRequest.go new file mode 100644 index 0000000..0bcb336 --- /dev/null +++ b/x/net/http/_demo/parallelRequest/parallelRequest.go @@ -0,0 +1,43 @@ +package main + +import ( + "fmt" + "sync" + + "github.com/goplus/llgoexamples/x/net/http" +) + +func worker(id int, wg *sync.WaitGroup) { + defer wg.Done() + resp, err := http.Get("http://www.baidu.com") + if err != nil { + fmt.Println(err) + return + } + fmt.Println(id, ":", resp.Status) + //body, err := io.ReadAll(resp.Body) + //if err != nil { + // fmt.Println(err) + // return + //} + //fmt.Println(string(body)) + resp.Body.Close() +} + +func main() { + var wait sync.WaitGroup + for i := 0; i < 500; i++ { + wait.Add(1) + go worker(i, &wait) + } + wait.Wait() + fmt.Println("All done") + + resp, err := http.Get("http://www.baidu.com") + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status) + resp.Body.Close() +} diff --git a/x/net/http/bodyChunk.go b/x/net/http/bodyChunk.go index c1d1072..01d9e74 100644 --- a/x/net/http/bodyChunk.go +++ b/x/net/http/bodyChunk.go @@ -2,73 +2,49 @@ package http import ( "errors" - "io" - "sync" "github.com/goplus/llgo/c/libuv" ) -type onceError struct { - sync.Mutex - err error -} - -func (a *onceError) Store(err error) { - a.Lock() - defer a.Unlock() - if a.err != nil { - return - } - a.err = err -} - -func (a *onceError) Load() error { - a.Lock() - defer a.Unlock() - return a.err -} - -func newBodyChunk(asyncHandle *libuv.Async) *bodyChunk { - return &bodyChunk{ - readCh: make(chan []byte, 1), - done: make(chan struct{}), - asyncHandle: asyncHandle, - } -} - type bodyChunk struct { chunk []byte readCh chan []byte asyncHandle *libuv.Async - once sync.Once done chan struct{} - rerr onceError + rerr error } var ( errClosedBodyChunk = errors.New("bodyChunk: read/write on closed body") ) +func newBodyChunk(asyncHandle *libuv.Async) *bodyChunk { + return &bodyChunk{ + readCh: make(chan []byte, 1), + done: make(chan struct{}), + asyncHandle: asyncHandle, + } +} + func (bc *bodyChunk) Read(p []byte) (n int, err error) { + select { + case <-bc.done: + err = bc.readCloseError() + return + default: + } + for n < len(p) { if len(bc.chunk) == 0 { + bc.asyncHandle.Send() select { - case chunk, ok := <-bc.readCh: - if !ok { - if n > 0 { - return n, nil - } - return 0, bc.readCloseError() - } + case chunk := <-bc.readCh: bc.chunk = chunk - bc.asyncHandle.Send() case <-bc.done: - if n > 0 { - return n, nil - } - return 0, io.EOF + err = bc.readCloseError() + return } } @@ -77,28 +53,28 @@ func (bc *bodyChunk) Read(p []byte) (n int, err error) { bc.chunk = bc.chunk[copied:] } - return n, nil + return } func (bc *bodyChunk) Close() error { - return bc.closeRead(nil) + return bc.closeWithError(nil) } func (bc *bodyChunk) readCloseError() error { - if rerr := bc.rerr.Load(); rerr != nil { + if rerr := bc.rerr; rerr != nil { return rerr } return errClosedBodyChunk } -func (bc *bodyChunk) closeRead(err error) error { +func (bc *bodyChunk) closeWithError(err error) error { + if bc.rerr != nil { + return nil + } if err == nil { - err = io.EOF + err = errClosedBodyChunk } - bc.rerr.Store(err) - bc.once.Do(func() { - close(bc.done) - }) - //close(bc.done) + bc.rerr = err + close(bc.done) return nil } diff --git a/x/net/http/client.go b/x/net/http/client.go index 7e26395..fa62732 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -11,8 +11,6 @@ import ( "reflect" "sort" "strings" - "sync" - "sync/atomic" "time" ) @@ -157,8 +155,7 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { URL: u, Header: make(Header), Host: host, - Cancel: ireq.Cancel, - ctx: ireq.ctx, + //Cancel: ireq.Cancel, timer: ireq.timer, timeoutch: ireq.timeoutch, @@ -307,16 +304,15 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d forkReq() } - // TODO(spongehah) tmp timeout(send) + // TODO(hah) tmp timeout(send): LLGo has not yet implemented startTimer. //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) req.timeoutch = make(chan struct{}, 1) req.deadline = deadline - req.ctx.Done() if deadline.IsZero() { didTimeout = alwaysFalse defer close(req.timeoutch) } else { - didTimeout = func() bool { return req.timer.GetDueIn() == 0 } + didTimeout = func() bool { return time.Now().After(deadline) } } resp, err = rt.RoundTrip(req) @@ -478,110 +474,83 @@ func (b *cancelTimerBody) Close() error { return err } -// knownRoundTripperImpl reports whether rt is a RoundTripper that's -// maintained by the Go team and known to implement the latest -// optional semantics (notably contexts). The Request is used -// to check whether this particular request is using an alternate protocol, -// in which case we need to check the RoundTripper for that protocol. -func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { - switch t := rt.(type) { - case *Transport: - if altRT := t.alternateRoundTripper(req); altRT != nil { - return knownRoundTripperImpl(altRT, req) - } - return true - //case *http2Transport, http2noDialH2RoundTripper: - // return true - } - // There's a very minor chance of a false positive with this. - // Instead of detecting our golang.org/x/net/http2.Transport, - // it might detect a Transport type in a different http2 - // package. But I know of none, and the only problem would be - // some temporarily leaked goroutines if the transport didn't - // support contexts. So this is a good enough heuristic: - if reflect.TypeOf(rt).String() == "*http2.Transport" { - return true - } - return false -} - -// setRequestCancel sets req.Cancel and adds a deadline context to req -// if deadline is non-zero. The RoundTripper's type is used to -// determine whether the legacy CancelRequest behavior should be used. +//// setRequestCancel sets req.Cancel and adds a deadline context to req +//// if deadline is non-zero. The RoundTripper's type is used to +//// determine whether the legacy CancelRequest behavior should be used. +//// +//// As background, there are three ways to cancel a request: +//// First was Transport.CancelRequest. (deprecated) +//// Second was Request.Cancel. +//// Third was Request.Context. +//// This function populates the second and third, and uses the first if it really needs to. +//func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) { +// if deadline.IsZero() { +// return nop, alwaysFalse +// } +// knownTransport := knownRoundTripperImpl(rt, req) +// oldCtx := req.Context() // -// As background, there are three ways to cancel a request: -// First was Transport.CancelRequest. (deprecated) -// Second was Request.Cancel. -// Third was Request.Context. -// This function populates the second and third, and uses the first if it really needs to. -func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) { - if deadline.IsZero() { - return nop, alwaysFalse - } - knownTransport := knownRoundTripperImpl(rt, req) - oldCtx := req.Context() - - if req.Cancel == nil && knownTransport { - // If they already had a Request.Context that's - // expiring sooner, do nothing: - if !timeBeforeContextDeadline(deadline, oldCtx) { - return nop, alwaysFalse - } - - var cancelCtx func() - req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) - return cancelCtx, func() bool { return time.Now().After(deadline) } - } - initialReqCancel := req.Cancel // the user's original Request.Cancel, if any - - var cancelCtx func() - if timeBeforeContextDeadline(deadline, oldCtx) { - req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) - } - - cancel := make(chan struct{}) - req.Cancel = cancel - - doCancel := func() { - // The second way in the func comment above: - close(cancel) - // The first way, used only for RoundTripper - // implementations written before Go 1.5 or Go 1.6. - type canceler interface{ CancelRequest(*Request) } - if v, ok := rt.(canceler); ok { - v.CancelRequest(req) - } - } - - stopTimerCh := make(chan struct{}) - var once sync.Once - stopTimer = func() { - once.Do(func() { - close(stopTimerCh) - if cancelCtx != nil { - cancelCtx() - } - }) - } - - timer := time.NewTimer(time.Until(deadline)) - var timedOut atomic.Bool - - go func() { - select { - case <-initialReqCancel: - doCancel() - timer.Stop() - case <-timer.C: - timedOut.Store(true) - doCancel() - case <-stopTimerCh: - timer.Stop() - } - }() - - return stopTimer, timedOut.Load -} +// if req.Cancel == nil && knownTransport { +// // If they already had a Request.Context that's +// // expiring sooner, do nothing: +// if !timeBeforeContextDeadline(deadline, oldCtx) { +// return nop, alwaysFalse +// } +// +// var cancelCtx func() +// req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) +// return cancelCtx, func() bool { return time.Now().After(deadline) } +// } +// initialReqCancel := req.Cancel // the user's original Request.Cancel, if any +// +// var cancelCtx func() +// if timeBeforeContextDeadline(deadline, oldCtx) { +// req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) +// } +// +// cancel := make(chan struct{}) +// req.Cancel = cancel +// +// doCancel := func() { +// // The second way in the func comment above: +// close(cancel) +// // The first way, used only for RoundTripper +// // implementations written before Go 1.5 or Go 1.6. +// type canceler interface{ CancelRequest(*Request) } +// if v, ok := rt.(canceler); ok { +// v.CancelRequest(req) +// } +// } +// +// stopTimerCh := make(chan struct{}) +// var once sync.Once +// stopTimer = func() { +// once.Do(func() { +// close(stopTimerCh) +// if cancelCtx != nil { +// cancelCtx() +// } +// }) +// } +// +// timer := time.NewTimer(time.Until(deadline)) +// var timedOut atomic.Bool +// +// go func() { +// select { +// case <-initialReqCancel: +// doCancel() +// timer.Stop() +// case <-timer.C: +// timedOut.Store(true) +// doCancel() +// case <-stopTimerCh: +// timer.Stop() +// } +// }() +// +// return stopTimer, timedOut.Load +//} // timeBeforeContextDeadline reports whether the non-zero Time t is // before ctx's deadline, if any. If ctx does not have a deadline, it @@ -594,7 +563,7 @@ func timeBeforeContextDeadline(t time.Time, ctx context.Context) bool { return t.Before(d) } -/*// knownRoundTripperImpl reports whether rt is a RoundTripper that's +// knownRoundTripperImpl reports whether rt is a RoundTripper that's // maintained by the Go team and known to implement the latest // optional semantics (notably contexts). The Request is used // to check whether this particular request is using an alternate protocol, @@ -619,7 +588,7 @@ func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { return true } return false -}*/ +} // makeHeadersCopier makes a function that copies headers from the // initial Request, ireq. For every redirect, this function must be called diff --git a/x/net/http/request.go b/x/net/http/request.go index e9279fc..37d6408 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -2,7 +2,6 @@ package http import ( "bytes" - "context" "errors" "fmt" "io" @@ -31,20 +30,16 @@ type Request struct { TransferEncoding []string Close bool Host string - //Form url.Values - //PostForm url.Values - //MultipartForm *multipart.Form - Trailer Header + // Form url.Values + // PostForm url.Values + // MultipartForm *multipart.Form RemoteAddr string RequestURI string - //TLS *tls.ConnectionState - Cancel <-chan struct{} Response *Response - ctx context.Context deadline time.Time - timeoutch chan struct{} //tmp timeout + timeoutch chan struct{} timer *libuv.Timer } @@ -75,34 +70,8 @@ var reqWriteExcludeHeader = map[string]bool{ type requestBodyReadError struct{ error } // NewRequest wraps NewRequestWithContext using context.Background. -func NewRequest(method, url string, body io.Reader) (*Request, error) { - return NewRequestWithContext(context.Background(), method, url, body) -} - -// NewRequestWithContext returns a new Request given a method, URL, and -// optional body. -// -// If the provided body is also an io.Closer, the returned -// Request.Body is set to body and will be closed by the Client -// methods Do, Post, and PostForm, and Transport.RoundTrip. -// -// NewRequestWithContext returns a Request suitable for use with -// Client.Do or Transport.RoundTrip. To create a request for use with -// testing a Server Handler, either use the NewRequest function in the -// net/http/httptest package, use ReadRequest, or manually update the -// Request fields. For an outgoing client request, the context -// controls the entire lifetime of a request and its response: -// obtaining a connection, sending the request, and reading the -// response headers and body. See the Request type's documentation for -// the difference between inbound and outbound request fields. -// -// If body is of type *bytes.Buffer, *bytes.Reader, or -// *strings.Reader, the returned request's ContentLength is set to its -// exact value (instead of -1), GetBody is populated (so 307 and 308 -// redirects can replay the body), and Body is set to NoBody if the -// ContentLength is 0. -func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.Reader) (*Request, error) { - // TODO(spongehah) Hyper only supports http +func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { + // TODO(hah) Hyper only supports http isHttpPrefix := strings.HasPrefix(urlStr, "http://") isHttpsPrefix := strings.HasPrefix(urlStr, "https://") if !isHttpPrefix && !isHttpsPrefix { @@ -121,9 +90,6 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R if !validMethod(method) { return nil, fmt.Errorf("net/http: invalid method %q", method) } - if ctx == nil { - return nil, errors.New("net/http: nil Context") - } u, err := url.Parse(urlStr) if err != nil { return nil, err @@ -135,7 +101,6 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R // The host's colon:port should be normalized. See Issue 14836. u.Host = removeEmptyPort(u.Host) req := &Request{ - ctx: ctx, Method: method, URL: u, Proto: "HTTP/1.1", @@ -228,24 +193,6 @@ func (r *Request) isReplayable() bool { return false } -// Context returns the request's context. To change the context, use -// Clone or WithContext. -// -// The returned context is always non-nil; it defaults to the -// background context. -// -// For outgoing client requests, the context controls cancellation. -// -// For incoming server requests, the context is canceled when the -// client's connection closes, the request is canceled (with HTTP/2), -// or when the ServeHTTP method returns. -func (r *Request) Context() context.Context { - if r.ctx != nil { - return r.ctx - } - return context.Background() -} - // AddCookie adds a cookie to the request. Per RFC 6265 section 5.4, // AddCookie does not attach more than one Cookie header field. That // means all cookies, if any, are written into the same line, @@ -300,7 +247,11 @@ func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hype } // Send it! sendTask := client.Send(hyperReq) - sendTask.SetUserdata(c.Pointer(taskData)) + if sendTask == nil { + println("############### write: sendTask is nil") + return errors.New("failed to send the request") + } + sendTask.SetUserdata(c.Pointer(taskData), nil) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { err = errors.New("failed to send the request") @@ -424,7 +375,7 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header, treq *tra // Wait for 100-continue if expected. if r.ProtoAtLeast(1, 1) && r.Body != nil && r.expectsContinue() { - hyperReq.OnInformational(printInformational, nil) + hyperReq.OnInformational(printInformational, nil, nil) } // Write body and trailer diff --git a/x/net/http/response.go b/x/net/http/response.go index a3a96fc..da7c3e4 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -81,13 +81,19 @@ func (r *Response) checkRespBody(taskData *taskData) (needContinue bool) { select { case taskData.resc <- responseAndError{res: r}: case <-taskData.callerGone: - readLoopDefer(pc, true) + if debugSwitch { + println("############### checkRespBody callerGone") + } + closeAndRemoveIdleConn(pc, true) return true } // Now that they've read from the unbuffered channel, they're safely // out of the select that also waits on this goroutine to die, so // we're allowed to exit now if needed (if alive is false) - readLoopDefer(pc, false) + if debugSwitch { + println("############### checkRespBody return") + } + closeAndRemoveIdleConn(pc, false) return true } return false @@ -97,6 +103,17 @@ func (r *Response) wrapRespBody(taskData *taskData) { body := &bodyEOFSignal{ body: r.Body, earlyCloseFn: func() error { + // If the response body is closed prematurely, + // the hyperBody needs to be recycled and the persistConn needs to be handled. + taskData.closeHyperBody() + select { + case <-taskData.pc.closech: + taskData.pc.t.removeIdleConn(taskData.pc) + default: + } + replaced := taskData.pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool + taskData.pc.alive = taskData.pc.alive && + replaced && taskData.pc.tryPutIdleConn() return nil }, fn: func(err error) error { @@ -110,7 +127,7 @@ func (r *Response) wrapRespBody(taskData *taskData) { }, } r.Body = body - // TODO(spongehah) gzip(wrapRespBody) + // TODO(hah) gzip(wrapRespBody): The compress/gzip library still has a bug. An exception occurs when calling gzip.NewReader(). //if taskData.addedGzip && EqualFold(r.Header.Get("Content-Encoding"), "gzip") { // println("gzip reader") // r.Body = &gzipReader{body: body} diff --git a/x/net/http/server.go b/x/net/http/server.go index 5c4c58d..f38cbd0 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -10,10 +10,3 @@ package http // size is anyway. (if we have the bytes on the machine, we might as // well read them) const maxPostHandlerReadBytes = 256 << 10 - -type readResult struct { - _ incomparable - n int - err error - b byte // byte read, if n == 1 -} diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index 818fb3c..12f3d70 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -28,7 +28,6 @@ type transferReader struct { ContentLength int64 Chunked bool Close bool - Trailer Header } // parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. @@ -151,10 +150,6 @@ func readTransfer(msg any, r io.ReadCloser) (err error) { t.ContentLength = realLength } - // TODO(spongehah) Trailer(readTransfer) - // Trailer - //t.Trailer, err = fixTrailer(t.Header, t.Chunked) - // If there is no Content-Length or chunked Transfer-Encoding on a *Response // and the status is not 1xx, 204 or 304, then the body is unbounded. // See RFC 7230, section 3.3. @@ -301,48 +296,6 @@ func parseContentLength(cl string) (int64, error) { } -// Parse the trailer header. -func fixTrailer(header Header, chunked bool) (Header, error) { - vv, ok := header["Trailer"] - if !ok { - return nil, nil - } - if !chunked { - // Trailer and no chunking: - // this is an invalid use case for trailer header. - // Nevertheless, no error will be returned and we - // let users decide if this is a valid HTTP message. - // The Trailer header will be kept in Response.Header - // but not populate Response.Trailer. - // See issue #27197. - return nil, nil - } - header.Del("Trailer") - - trailer := make(Header) - var err error - for _, v := range vv { - foreachHeaderElement(v, func(key string) { - key = CanonicalHeaderKey(key) - switch key { - case "Transfer-Encoding", "Trailer", "Content-Length": - if err == nil { - err = badStringError("bad trailer key", key) - return - } - } - trailer[key] = nil - }) - } - if err != nil { - return nil, err - } - if len(trailer) == 0 { - return nil, nil - } - return trailer, nil -} - // body turns a Reader into a ReadCloser. // Close ensures that the body has been fully read // and then reads the trailer if necessary. @@ -387,16 +340,6 @@ func (b *body) readLocked(p []byte) (n int, err error) { b.sawEOF = true // Chunked case. Read the trailer. if b.hdr != nil { - // TODO(spongehah) Trailer(b.readLocked) - //if e := b.readTrailer(); e != nil { - // err = e - // // Something went wrong in the trailer, we must not allow any - // // further reads of any kind to succeed from body, nor any - // // subsequent requests on the server connection. See - // // golang.org/issue/12027 - // b.sawEOF = false - // b.closed = true - //} b.hdr = nil } else { // If the server declared the Content-Length, our body is a LimitedReader @@ -634,7 +577,6 @@ func (r *Request) writeHeader(reqHeaders *hyper.Headers) error { // 'Content-Length' and 'Transfer-Encoding:chunked' are already handled by hyper // Write Trailer header - // TODO(spongehah) Trailer(writeHeader) return nil } @@ -682,7 +624,7 @@ func (r *Request) writeBody(hyperReq *hyper.Request, treq *transportRequest) err buf: buf, treq: treq, } - hyperReqBody.SetUserdata(c.Pointer(reqData)) + hyperReqBody.SetUserdata(c.Pointer(reqData), nil) hyperReqBody.SetDataFunc(setPostData) hyperReq.SetBody(hyperReqBody) } diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 8075133..e47bd2a 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "hash/fnv" "io" "log" "net/url" @@ -27,7 +28,6 @@ import ( // as directed by the environment variables HTTP_PROXY, HTTPS_PROXY // and NO_PROXY (or the lowercase versions thereof). var DefaultTransport RoundTripper = &Transport{ - //Proxy: ProxyFromEnvironment, Proxy: nil, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, @@ -36,6 +36,7 @@ var DefaultTransport RoundTripper = &Transport{ // DefaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 +const _SC_NPROCESSORS_ONLN c.Int = 58 // Debug switch provided for developers const ( @@ -69,11 +70,10 @@ type Transport struct { MaxConnsPerHost int IdleConnTimeout time.Duration - // libuv and hyper related - loopInitOnce sync.Once - loop *libuv.Loop - async *libuv.Async - exec *hyper.Executor + loopsMu sync.Mutex + loops []*clientEventLoop + isClosing atomic.Bool + //curLoop atomic.Uint32 } // A cancelKey is the key of the reqCanceler map. @@ -183,6 +183,9 @@ func (tr *transportRequest) setError(err error) { func (t *Transport) putOrCloseIdleConn(pconn *persistConn) { if err := t.tryPutIdleConn(pconn); err != nil { + if debugSwitch { + println("############### putOrCloseIdleConn: close") + } pconn.close(err) } } @@ -274,6 +277,9 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { t.idleLRU.add(pconn) if t.MaxIdleConns != 0 && t.idleLRU.len() > t.MaxIdleConns { oldest := t.idleLRU.removeOldest() + if debugSwitch { + println("############### tryPutIdleConn: removeOldest") + } oldest.close(errTooManyIdle) t.removeIdleConnLocked(oldest) } @@ -287,7 +293,7 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { pconn.idleTimer.Start(onIdleConnTimeout, idleConnTimeout, 0) } else { pconn.idleTimer = &libuv.Timer{} - libuv.InitTimer(t.loop, pconn.idleTimer) + libuv.InitTimer(pconn.eventLoop.loop, pconn.idleTimer) (*libuv.Handle)(c.Pointer(pconn.idleTimer)).SetData(c.Pointer(pconn)) pconn.idleTimer.Start(onIdleConnTimeout, idleConnTimeout, 0) } @@ -343,7 +349,9 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) { // See whether this connection has been idle too long, considering // only the wall time (the Round(0)), in case this is a laptop or VM // coming out of suspend with previously cached idle connections. - tooOld := !oldTime.IsZero() && pconn.idleAt.Round(0).Before(oldTime) + // FIXME: Round() is not supported in llgo + //tooOld := !oldTime.IsZero() && pconn.idleAt.Round(0).Before(oldTime) + tooOld := !oldTime.IsZero() && pconn.idleAt.Before(oldTime) if tooOld { // Async cleanup. Launch in its own goroutine (as if a // time.AfterFunc called it); it acquires idleMu, which we're @@ -403,9 +411,10 @@ func (t *Transport) removeIdleConn(pconn *persistConn) bool { // t.idleMu must be held. func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool { - if pconn.idleTimer != nil { + if pconn.idleTimer != nil && (*libuv.Handle)(c.Pointer(pconn.idleTimer)).IsClosing() == 0 { pconn.idleTimer.Stop() (*libuv.Handle)(c.Pointer(pconn.idleTimer)).Close(nil) + pconn.idleTimer = nil } t.idleLRU.remove(pconn) key := pconn.cacheKey @@ -467,13 +476,14 @@ func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { return true } -func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { +func (t *Transport) connectMethodForRequest(treq *transportRequest, loop *clientEventLoop) (cm connectMethod, err error) { cm.targetScheme = treq.URL.Scheme cm.targetAddr = canonicalAddr(treq.URL) if t.Proxy != nil { cm.proxyURL, err = t.Proxy(treq.Request) } cm.onlyH1 = treq.requiresHTTP1() + cm.eventLoop = loop return cm, err } @@ -524,25 +534,56 @@ func (t *Transport) cancelRequest(key cancelKey, err error) bool { return cancel != nil } -func (t *Transport) close(err error) { - t.reqMu.Lock() - defer t.reqMu.Unlock() - t.closeLocked(err) +func (t *Transport) Close() { + if t != nil && !t.isClosing.Swap(true) { + t.CloseIdleConnections() + for _, el := range t.loops { + el.Close() + } + } } -func (t *Transport) closeLocked(err error) { - if err != nil { - fmt.Println(err) - } - if t.loop != nil { - t.loop.Close() - } - if t.async != nil { - t.async.Close(nil) +type clientEventLoop struct { + // libuv and hyper related + loop *libuv.Loop + async *libuv.Async + exec *hyper.Executor + isRunning atomic.Bool + isClosing atomic.Bool +} + +func (el *clientEventLoop) Close() { + if el != nil && !el.isClosing.Swap(true) { + if el.loop != nil && (*libuv.Handle)(c.Pointer(el.loop)).IsClosing() == 0 { + el.loop.Close() + el.loop = nil + } + if el.async != nil && (*libuv.Handle)(c.Pointer(el.async)).IsClosing() == 0 { + el.async.Close(nil) + el.async = nil + } + if el.exec != nil { + el.exec.Free() + el.exec = nil + } } - if t.exec != nil { - t.exec.Free() +} + +func (el *clientEventLoop) run() { + if el.isRunning.Load() { + return } + + el.loop.Async(el.async, nil) + + checker := &libuv.Idle{} + libuv.InitIdle(el.loop, checker) + (*libuv.Handle)(c.Pointer(checker)).SetData(c.Pointer(el)) + checker.Start(readWriteLoop) + + go el.loop.Run(libuv.RUN_DEFAULT) + + el.isRunning.Store(true) } // ---------------------------------------------------------- @@ -556,26 +597,65 @@ func getMilliseconds(deadline time.Time) uint64 { return uint64(milliseconds) } +var cpuCount int + +func init() { + cpuCount = int(c.Sysconf(_SC_NPROCESSORS_ONLN)) + if cpuCount <= 0 { + cpuCount = 4 + } +} + +func (t *Transport) getOrInitClientEventLoop(i uint32) *clientEventLoop { + if el := t.loops[i]; el != nil { + return el + } + + eventLoop := &clientEventLoop{ + loop: libuv.LoopNew(), + async: &libuv.Async{}, + exec: hyper.NewExecutor(), + } + + eventLoop.run() + + t.loops[i] = eventLoop + return eventLoop +} + +func (t *Transport) getClientEventLoop(req *Request) *clientEventLoop { + t.loopsMu.Lock() + defer t.loopsMu.Unlock() + if t.loops == nil { + t.loops = make([]*clientEventLoop, cpuCount) + } + + key := t.getLoopKey(req) + h := fnv.New32a() + h.Write([]byte(key)) + hashcode := h.Sum32() + + return t.getOrInitClientEventLoop(hashcode % uint32(cpuCount)) + //i := (t.curLoop.Add(1) - 1) % uint32(cpuCount) + //return t.getOrInitClientEventLoop(i) +} + +func (t *Transport) getLoopKey(req *Request) string { + proxyStr := "" + if t.Proxy != nil { + proxyURL, _ := t.Proxy(req) + proxyStr = proxyURL.String() + } + return req.URL.String() + proxyStr +} + func (t *Transport) RoundTrip(req *Request) (*Response, error) { if debugSwitch { println("############### RoundTrip start") defer println("############### RoundTrip end") } - t.loopInitOnce.Do(func() { - println("############### init loop") - t.loop = libuv.LoopNew() - t.async = &libuv.Async{} - t.exec = hyper.NewExecutor() - - t.loop.Async(t.async, nil) - checker := &libuv.Check{} - libuv.InitCheck(t.loop, checker) - (*libuv.Handle)(c.Pointer(checker)).SetData(c.Pointer(t)) - checker.Start(readWriteLoop) - - go t.loop.Run(libuv.RUN_DEFAULT) - }) + eventLoop := t.getClientEventLoop(req) // If timeout is set, start the timer var didTimeout func() bool @@ -583,7 +663,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { // Only the first request will initialize the timer if req.timer == nil && !req.deadline.IsZero() { req.timer = &libuv.Timer{} - libuv.InitTimer(t.loop, req.timer) + libuv.InitTimer(eventLoop.loop, req.timer) ch := &timeoutData{ timeoutch: req.timeoutch, taskData: nil, @@ -598,7 +678,9 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { stopTimer = func() { close(req.timeoutch) req.timer.Stop() - (*libuv.Handle)(c.Pointer(req.timer)).Close(nil) + if (*libuv.Handle)(c.Pointer(req.timer)).IsClosing() == 0 { + (*libuv.Handle)(c.Pointer(req.timer)).Close(nil) + } if debugSwitch { println("############### timer close") } @@ -608,7 +690,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { stopTimer = nop } - resp, err := t.doRoundTrip(req) + resp, err := t.doRoundTrip(req, eventLoop) if err != nil { stopTimer() return nil, err @@ -624,7 +706,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { return resp, nil } -func (t *Transport) doRoundTrip(req *Request) (*Response, error) { +func (t *Transport) doRoundTrip(req *Request, loop *clientEventLoop) (*Response, error) { if debugSwitch { println("############### doRoundTrip start") defer println("############### doRoundTrip end") @@ -687,12 +769,6 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { } for { - //select { - //case <-ctx.Done(): - // req.closeBody() - // return nil, ctx.Err() - //default: - //} select { case <-req.timeoutch: req.closeBody() @@ -703,7 +779,7 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { // treq gets modified by roundTrip, so we need to recreate for each retry. //treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey} treq := &transportRequest{Request: req, cancelKey: cancelKey} - cm, err := t.connectMethodForRequest(treq) + cm, err := t.connectMethodForRequest(treq, loop) if err != nil { req.closeBody() return nil, err @@ -716,6 +792,7 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { pconn, err := t.getConn(treq, cm) if err != nil { + println("################# getConn err != nil") t.setReqCanceler(cancelKey, nil) req.closeBody() return nil, err @@ -827,10 +904,6 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // what caused w.err; if so, prefer to return the // cancellation error (see golang.org/issue/16049). select { - //case <-req.Cancel: - // return nil, errRequestCanceledConn - //case <-req.Context().Done(): - // return nil, req.Context().Err() case <-req.timeoutch: if debugSwitch { println("############### getConn: timeoutch") @@ -977,68 +1050,23 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * writeLoopDone: make(chan struct{}, 1), alive: true, chunkAsync: &libuv.Async{}, + eventLoop: cm.eventLoop, } - t.loop.Async(pconn.chunkAsync, readyToRead) + cm.eventLoop.loop.Async(pconn.chunkAsync, readyToRead) - //trace := httptrace.ContextClientTrace(ctx) - //wrapErr := func(err error) error { - // if cm.proxyURL != nil { - // // Return a typed error, per Issue 16997 - // return &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err} - // } - // return err - //} - // - //if cm.scheme() == "https" && t.hasCustomTLSDialer() { - // var err error - // pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr()) - // if err != nil { - // return nil, wrapErr(err) - // } - // if tc, ok := pconn.conn.(*tls.Conn); ok { - // // Handshake here, in case DialTLS didn't. TLSNextProto below - // // depends on it for knowing the connection state. - // if trace != nil && trace.TLSHandshakeStart != nil { - // trace.TLSHandshakeStart() - // } - // if err := tc.HandshakeContext(ctx); err != nil { - // go pconn.conn.Close() - // if trace != nil && trace.TLSHandshakeDone != nil { - // trace.TLSHandshakeDone(tls.ConnectionState{}, err) - // } - // return nil, err - // } - // cs := tc.ConnectionState() - // if trace != nil && trace.TLSHandshakeDone != nil { - // trace.TLSHandshakeDone(cs, nil) - // } - // pconn.tlsState = &cs - // } - //} else { - //conn, err := t.dial(ctx, "tcp", cm.addr()) - conn, err := t.dial(cm.addr()) + conn, err := t.dial(cm) if err != nil { return nil, err } pconn.conn = conn - //if cm.scheme() == "https" { - // var firstTLSHost string - // if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { - // return nil, wrapErr(err) - // } - // if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil { - // return nil, wrapErr(err) - // } - //} - //} select { case <-timeoutch: conn.Close() return default: } - // TODO(spongehah) Proxy(https/sock5)(t.dialConn) + // TODO(hah) Proxy(https/sock5)(t.dialConn) // Proxy setup. switch { case cm.proxyURL == nil: @@ -1054,41 +1082,14 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * // case cm.targetScheme == "https": } - //if cm.proxyURL != nil && cm.targetScheme == "https" { - // if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil { - // return nil, err - // } - //} - // - //if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { - // if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { - // alt := next(cm.targetAddr, pconn.conn.(*tls.Conn)) - // if e, ok := alt.(erringRoundTripper); ok { - // // pconn.conn was closed by next (http2configureTransports.upgradeFn). - // return nil, e.RoundTripErr() - // } - // return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: alt}, nil - // } - //} - pconn.closeErr = errReadLoopExiting - pconn.tryPutIdleConn = func() bool { - if err := pconn.t.tryPutIdleConn(pconn); err != nil { - pconn.closeErr = err - //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { - // trace.PutIdleConn(err) - //} - return false - } - //if trace != nil && trace.PutIdleConn != nil { - // trace.PutIdleConn(nil) - //} - return true - } select { case <-timeoutch: err = errors.New("[t.dialConn] request timeout") + if debugSwitch { + println("############### dialConn: timeoutch") + } pconn.close(err) return nil, err default: @@ -1096,11 +1097,12 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * return pconn, nil } -func (t *Transport) dial(addr string) (*connData, error) { +func (t *Transport) dial(cm connectMethod) (*connData, error) { if debugSwitch { println("############### dial start") defer println("############### dial end") } + addr := cm.addr() host, port, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -1108,8 +1110,8 @@ func (t *Transport) dial(addr string) (*connData, error) { conn := new(connData) - libuv.InitTcp(t.loop, &conn.TcpHandle) - (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) + libuv.InitTcp(cm.eventLoop.loop, &conn.tcpHandle) + (*libuv.Handle)(c.Pointer(&conn.tcpHandle)).SetData(c.Pointer(conn)) var hints cnet.AddrInfo c.Memset(c.Pointer(&hints), 0, unsafe.Sizeof(hints)) @@ -1122,8 +1124,8 @@ func (t *Transport) dial(addr string) (*connData, error) { return nil, fmt.Errorf("getaddrinfo error\n") } - (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) - status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, onConnect) + (*libuv.Req)(c.Pointer(&conn.connectReq)).SetData(c.Pointer(conn)) + status = libuv.TcpConnect(&conn.connectReq, &conn.tcpHandle, res.Addr, onConnect) if status != 0 { return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } @@ -1179,33 +1181,32 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err resc: resc, } - if pc.client == nil && !pc.isReused() { - // Hookup the IO - hyperIo := newHyperIo(pc.conn) - // We need an executor generally to poll futures - // Prepare client options - opts := hyper.NewClientConnOptions() - opts.Exec(pc.t.exec) - // send the handshake - handshakeTask := hyper.Handshake(hyperIo, opts) - taskData.taskId = handshake - handshakeTask.SetUserdata(c.Pointer(taskData)) - // Send the request to readWriteLoop(). - pc.t.exec.Push(handshakeTask) - } else { - taskData.taskId = read - err = req.write(pc.client, taskData, pc.t.exec) - if err != nil { - writeErrCh <- err - } - } + //if pc.client == nil && !pc.isReused() { + // Hookup the IO + hyperIo := newHyperIo(pc.conn) + // We need an executor generally to poll futures + // Prepare client options + opts := hyper.NewClientConnOptions() + opts.Exec(pc.eventLoop.exec) + // send the handshake + handshakeTask := hyper.Handshake(hyperIo, opts) + taskData.taskId = handshake + handshakeTask.SetUserdata(c.Pointer(taskData), nil) + // Send the request to readWriteLoop(). + pc.eventLoop.exec.Push(handshakeTask) + //} else { + // println("############### roundTrip: pc.client != nil") + // taskData.taskId = read + // err = req.write(pc.client, taskData, pc.eventLoop.exec) + // if err != nil { + // writeErrCh <- err + // pc.close(err) + // } + //} // Wake up libuv. Loop - pc.t.async.Send() + pc.eventLoop.async.Send() - //var respHeaderTimer <-chan time.Time - //cancelChan := req.Request.Cancel - //ctxDoneChan := req.Context().Done() timeoutch := req.timeoutch pcClosed := pc.closech canceled := false @@ -1221,6 +1222,9 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err println("############### roundTrip: writeErrch") } if err != nil { + if debugSwitch { + println("############### roundTrip: writeErrch err != nil") + } pc.close(fmt.Errorf("write error: %w", err)) if pc.conn.nwrite == startBytesWritten { err = nothingWrittenError{err} @@ -1247,13 +1251,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) } return re.res, nil - //case <-cancelChan: - // canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) - // cancelChan = nil - //case <-ctxDoneChan: - // canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err()) - // cancelChan = nil - // ctxDoneChan = nil case <-timeoutch: if debugSwitch { println("############### roundTrip: timeoutch") @@ -1267,21 +1264,21 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // readWriteLoop handles the main I/O loop for a persistent connection. // It processes incoming requests, sends them to the server, and handles responses. -func readWriteLoop(checker *libuv.Check) { - t := (*Transport)((*libuv.Handle)(c.Pointer(checker)).GetData()) +func readWriteLoop(checker *libuv.Idle) { + eventLoop := (*clientEventLoop)((*libuv.Handle)(c.Pointer(checker)).GetData()) // The polling state machine! Poll all ready tasks and act on them... - task := t.exec.Poll() + task := eventLoop.exec.Poll() for task != nil { if debugSwitch { println("############### polling") } - t.handleTask(task) - task = t.exec.Poll() + eventLoop.handleTask(task) + task = eventLoop.exec.Poll() } } -func (t *Transport) handleTask(task *hyper.Task) { +func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { taskData := (*taskData)(task.Userdata()) if taskData == nil { // A background task for hyper_client completed... @@ -1293,7 +1290,10 @@ func (t *Transport) handleTask(task *hyper.Task) { // If original taskId is set, we need to check it err = checkTaskType(task, taskData) if err != nil { - readLoopDefer(pc, true) + if debugSwitch { + println("############### handleTask: checkTaskType err != nil") + } + closeAndRemoveIdleConn(pc, true) return } switch taskData.taskId { @@ -1313,13 +1313,16 @@ func (t *Transport) handleTask(task *hyper.Task) { pc.client = (*hyper.ClientConn)(task.Value()) task.Free() - // TODO(spongehah) Proxy(writeLoop) + // TODO(hah) Proxy(writeLoop) taskData.taskId = read - err = taskData.req.Request.write(pc.client, taskData, t.exec) + err = taskData.req.Request.write(pc.client, taskData, eventLoop.exec) if err != nil { //pc.writeErrCh <- err // to the body reader, which might recycle us taskData.writeErrCh <- err // to the roundTrip function + if debugSwitch { + println("############### handleTask: write err != nil") + } pc.close(err) return } @@ -1332,6 +1335,20 @@ func (t *Transport) handleTask(task *hyper.Task) { println("############### read") } + pc.tryPutIdleConn = func() bool { + if err := pc.t.tryPutIdleConn(pc); err != nil { + pc.closeErr = err + //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // trace.PutIdleConn(err) + //} + return false + } + //if trace != nil && trace.PutIdleConn != nil { + // trace.PutIdleConn(nil) + //} + return true + } + // Take the results hyperResp := (*hyper.Response)(task.Value()) task.Free() @@ -1340,7 +1357,10 @@ func (t *Transport) handleTask(task *hyper.Task) { if pc.numExpectedResponses == 0 { pc.readLoopPeekFailLocked(hyperResp, err) pc.mu.Unlock() - readLoopDefer(pc, true) + if debugSwitch { + println("############### handleTask: numExpectedResponses == 0") + } + closeAndRemoveIdleConn(pc, true) return } //pc.mu.Unlock() @@ -1361,20 +1381,25 @@ func (t *Transport) handleTask(task *hyper.Task) { hyperResp.Free() if err != nil { + pc.bodyChunk.closeWithError(err) + taskData.closeHyperBody() select { case taskData.resc <- responseAndError{err: err}: case <-taskData.callerGone: - readLoopDefer(pc, true) + if debugSwitch { + println("############### handleTask read: callerGone") + } + closeAndRemoveIdleConn(pc, true) return } - readLoopDefer(pc, true) + if debugSwitch { + println("############### handleTask: read err != nil") + } + closeAndRemoveIdleConn(pc, true) return } - dataTask := taskData.hyperBody.Data() taskData.taskId = readBodyChunk - dataTask.SetUserdata(c.Pointer(taskData)) - t.exec.Push(dataTask) if !taskData.req.deadline.IsZero() { (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData @@ -1391,21 +1416,18 @@ func (t *Transport) handleTask(task *hyper.Task) { resp.wrapRespBody(taskData) - // FIXME: Waiting for the channel bug to be fixed - //select { - //case taskData.resc <- responseAndError{res: resp}: - //case <-taskData.callerGone: - // // defer - // readLoopDefer(pc, true) - // return - //} select { + case taskData.resc <- responseAndError{res: resp}: case <-taskData.callerGone: - readLoopDefer(pc, true) + // defer + if debugSwitch { + println("############### handleTask read: callerGone 2") + } + pc.bodyChunk.Close() + taskData.closeHyperBody() + closeAndRemoveIdleConn(pc, true) return - default: } - taskData.resc <- responseAndError{res: resp} if debugReadWriteLoop { println("############### read end") @@ -1433,14 +1455,16 @@ func (t *Transport) handleTask(task *hyper.Task) { // taskType == taskEmpty (check in checkTaskType) task.Free() - taskData.hyperBody.Free() - taskData.hyperBody = nil - pc.bodyChunk.Close() - replaced := t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool + pc.bodyChunk.closeWithError(io.EOF) + taskData.closeHyperBody() + replaced := pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool pc.alive = pc.alive && replaced && pc.tryPutIdleConn() - readLoopDefer(pc, false) + if debugSwitch { + println("############### handleTask readBodyChunk: alive: ", pc.alive) + } + closeAndRemoveIdleConn(pc, false) if debugReadWriteLoop { println("############### readBodyChunk end [empty]") @@ -1449,18 +1473,20 @@ func (t *Transport) handleTask(task *hyper.Task) { } func readyToRead(aysnc *libuv.Async) { - println("############### AsyncCb: readyToRead") taskData := (*taskData)(aysnc.GetData()) dataTask := taskData.hyperBody.Data() - dataTask.SetUserdata(c.Pointer(taskData)) - taskData.pc.t.exec.Push(dataTask) + dataTask.SetUserdata(c.Pointer(taskData), nil) + taskData.pc.eventLoop.exec.Push(dataTask) } -// readLoopDefer Replace the defer function of readLoop in stdlib -func readLoopDefer(pc *persistConn, force bool) { +// closeAndRemoveIdleConn Replace the defer function of readLoop in stdlib +func closeAndRemoveIdleConn(pc *persistConn, force bool) { if pc.alive == true && !force { return } + if debugSwitch { + println("############### closeAndRemoveIdleConn, force:", force) + } pc.close(pc.closeErr) pc.t.removeIdleConn(pc) } @@ -1468,13 +1494,14 @@ func readLoopDefer(pc *persistConn, force bool) { // ---------------------------------------------------------- type connData struct { - TcpHandle libuv.Tcp - ConnectReq libuv.Connect - ReadBuf libuv.Buf - ReadBufFilled uintptr + tcpHandle libuv.Tcp + connectReq libuv.Connect + readBuf libuv.Buf + readBufFilled uintptr nwrite int64 // bytes written(Replaced from persistConn's nwrite) - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker + readWaker *hyper.Waker + writeWaker *hyper.Waker + isClosing atomic.Bool } type taskData struct { @@ -1497,24 +1524,32 @@ const ( readBodyChunk ) -func (conn *connData) Close() error { - if conn == nil { - return nil - } - if conn.ReadWaker != nil { - conn.ReadWaker.Free() - conn.ReadWaker = nil - } - if conn.WriteWaker != nil { - conn.WriteWaker.Free() - conn.WriteWaker = nil +func (conn *connData) Close() { + if conn != nil && !conn.isClosing.Swap(true) { + if conn.readWaker != nil { + conn.readWaker.Free() + conn.readWaker = nil + } + if conn.writeWaker != nil { + conn.writeWaker.Free() + conn.writeWaker = nil + } + //if conn.readBuf.Base != nil { + // c.Free(c.Pointer(conn.readBuf.Base)) + // conn.readBuf.Base = nil + //} + if (*libuv.Handle)(c.Pointer(&conn.tcpHandle)).IsClosing() == 0 { + (*libuv.Handle)(c.Pointer(&conn.tcpHandle)).Close(nil) + } + conn = nil } - if conn.ReadBuf.Base != nil { - c.Free(c.Pointer(conn.ReadBuf.Base)) - conn.ReadBuf.Base = nil +} + +func (d *taskData) closeHyperBody() { + if d.hyperBody != nil { + d.hyperBody.Free() + d.hyperBody = nil } - (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).Close(nil) - return nil } // onConnect is the libuv callback for a successful connection @@ -1524,24 +1559,28 @@ func onConnect(req *libuv.Connect, status c.Int) { defer println("############### connect end") } conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) - if status < 0 { - c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) + c.Fprintf(c.Stderr, c.Str("connect error: %s\n"), c.GoString(libuv.Strerror(libuv.Errno(status)))) + conn.Close() return } - (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(allocBuffer, onRead) + + // Keep-Alive + conn.tcpHandle.KeepAlive(1, 60) + + (*libuv.Stream)(c.Pointer(&conn.tcpHandle)).StartRead(allocBuffer, onRead) } // allocBuffer allocates a buffer for reading from a socket func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { conn := (*connData)(handle.GetData()) - if conn.ReadBuf.Base == nil { - conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) - //base := make([]byte, suggestedSize) - //conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Pointer(&base[0])), c.Uint(suggestedSize)) - conn.ReadBufFilled = 0 + if conn.readBuf.Base == nil { + //conn.readBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) + base := make([]byte, suggestedSize) + conn.readBuf = libuv.InitBuf((*c.Char)(c.Pointer(&base[0])), c.Uint(suggestedSize)) + conn.readBufFilled = 0 } - *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+conn.ReadBufFilled)), c.Uint(suggestedSize-conn.ReadBufFilled)) + *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.readBuf.Base))+conn.readBufFilled)), c.Uint(suggestedSize-conn.readBufFilled)) } // onRead is the libuv callback for reading from a socket @@ -1549,38 +1588,39 @@ func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { conn := (*connData)((*libuv.Handle)(c.Pointer(stream)).GetData()) if nread > 0 { - conn.ReadBufFilled += uintptr(nread) + conn.readBufFilled += uintptr(nread) } - if conn.ReadWaker != nil { + if conn.readWaker != nil { // Wake up the pending read operation of Hyper - conn.ReadWaker.Wake() - conn.ReadWaker = nil + conn.readWaker.Wake() + conn.readWaker = nil } } // readCallBack read callback function for Hyper library func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { conn := (*connData)(userdata) - if conn.ReadBufFilled > 0 { + if conn.readBufFilled > 0 { var toCopy uintptr - if bufLen < conn.ReadBufFilled { + if bufLen < conn.readBufFilled { toCopy = bufLen } else { - toCopy = conn.ReadBufFilled + toCopy = conn.readBufFilled } // Copy data from read buffer to Hyper's buffer - c.Memcpy(c.Pointer(buf), c.Pointer(conn.ReadBuf.Base), toCopy) + c.Memcpy(c.Pointer(buf), c.Pointer(conn.readBuf.Base), toCopy) // Move remaining data to the beginning of the buffer - c.Memmove(c.Pointer(conn.ReadBuf.Base), c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+toCopy), conn.ReadBufFilled-toCopy) + c.Memmove(c.Pointer(conn.readBuf.Base), c.Pointer(uintptr(c.Pointer(conn.readBuf.Base))+toCopy), conn.readBufFilled-toCopy) // Update the amount of filled buffer - conn.ReadBufFilled -= toCopy + conn.readBufFilled -= toCopy return toCopy } - if conn.ReadWaker != nil { - conn.ReadWaker.Free() + if conn.readWaker != nil { + conn.readWaker.Free() } - conn.ReadWaker = ctx.Waker() + conn.readWaker = ctx.Waker() + println("############### readCallBack: IoPending") return hyper.IoPending } @@ -1588,10 +1628,10 @@ func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uin // Callback function called after a write operation completes func onWrite(req *libuv.Write, status c.Int) { conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) - if conn.WriteWaker != nil { + if conn.writeWaker != nil { // Wake up the pending write operation - conn.WriteWaker.Wake() - conn.WriteWaker = nil + conn.writeWaker.Wake() + conn.writeWaker = nil } } @@ -1602,16 +1642,17 @@ func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui req := &libuv.Write{} (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) - ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, onWrite) + ret := req.Write((*libuv.Stream)(c.Pointer(&conn.tcpHandle)), &initBuf, 1, onWrite) if ret >= 0 { conn.nwrite += int64(bufLen) return bufLen } - if conn.WriteWaker != nil { - conn.WriteWaker.Free() + if conn.writeWaker != nil { + conn.writeWaker.Free() } - conn.WriteWaker = ctx.Waker() + conn.writeWaker = ctx.Waker() + println("############### writeCallBack: IoPending") return hyper.IoPending } @@ -1630,14 +1671,14 @@ func onTimeout(timer *libuv.Timer) { pc := taskData.pc pc.alive = false pc.t.cancelRequest(taskData.req.cancelKey, errors.New("timeout: req.Context().Err()")) - readLoopDefer(pc, true) + closeAndRemoveIdleConn(pc, true) } } // newHyperIo creates a new IO with read and write callbacks func newHyperIo(connData *connData) *hyper.Io { hyperIo := hyper.NewIo() - hyperIo.SetUserdata(c.Pointer(connData)) + hyperIo.SetUserdata(c.Pointer(connData), nil) hyperIo.SetRead(readCallBack) hyperIo.SetWrite(writeCallBack) return hyperIo @@ -1670,8 +1711,16 @@ func checkTaskType(task *hyper.Task, taskData *taskData) (err error) { task.Free() if curTaskId == handshake || curTaskId == read { taskData.writeErrCh <- err + if debugSwitch { + println("############### checkTaskType: writeErrCh") + } taskData.pc.close(err) } + if taskData.pc.bodyChunk != nil { + taskData.pc.bodyChunk.Close() + taskData.pc.bodyChunk = nil + } + taskData.closeHyperBody() taskData.pc.alive = false } return @@ -1685,6 +1734,7 @@ func fail(err *hyper.Error, taskId taskId) error { errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) errDetails := unsafe.SliceData(errBuf[:errLen]) details := c.GoString(errDetails) + fmt.Println(details) // clean up the error err.Free() @@ -1837,7 +1887,9 @@ type persistConn struct { // If it's non-nil, the rest of the fields are unused. alt RoundTripper - t *Transport + t *Transport + eventLoop *clientEventLoop + cacheKey connectMethodKey conn *connData //tlsState *tls.ConnectionState @@ -1876,6 +1928,9 @@ type persistConn struct { // a "keep-alive" state. It does not interrupt any connections currently // in use. func (t *Transport) CloseIdleConnections() { + if debugSwitch { + println("############### CloseIdleConnections") + } //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) t.idleMu.Lock() m := t.idleConn @@ -1888,12 +1943,16 @@ func (t *Transport) CloseIdleConnections() { pconn.close(errCloseIdleConns) } } + //if t2 := t.h2transport; t2 != nil { // t2.CloseIdleConnections() //} } func (pc *persistConn) cancelRequest(err error) { + if debugSwitch { + println("############### cancelRequest") + } pc.mu.Lock() defer pc.mu.Unlock() pc.canceledErr = err @@ -1938,8 +1997,14 @@ func (pc *persistConn) closeLocked(err error) { } close(pc.closech) close(pc.writeLoopDone) - pc.client.Free() - pc.chunkAsync.Close(nil) + if pc.client != nil { + pc.client.Free() + pc.client = nil + } + if pc.chunkAsync != nil && pc.chunkAsync.IsClosing() == 0 { + pc.chunkAsync.Close(nil) + pc.chunkAsync = nil + } } } pc.mutateHeaderFunc = nil @@ -2096,10 +2161,16 @@ func (pc *persistConn) closeConnIfStillIdleLocked() { return } t.removeIdleConnLocked(pc) + if debugSwitch { + println("############### closeConnIfStillIdleLocked") + } pc.close(errIdleConnTimeout) } func (pc *persistConn) readLoopPeekFailLocked(resp *hyper.Response, err error) { + if debugSwitch { + println("############### readLoopPeekFailLocked") + } if pc.closed != nil { return } @@ -2117,7 +2188,7 @@ func (pc *persistConn) setExtraHeaders(req *transportRequest) bool { // uncompress the gzip stream if we were the layer that // requested it. requestedGzip := false - // TODO(spongehah) gzip(pc.roundTrip) + // TODO(hah) gzip(pc.roundTrip): The compress/gzip library still has a bug. An exception occurs when calling gzip.NewReader(). //if !pc.t.DisableCompression && // req.Header.Get("Accept-Encoding") == "" && // req.Header.Get("Range") == "" && @@ -2190,6 +2261,8 @@ type connectMethod struct { // be reused for different targetAddr values. targetAddr string onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 + + eventLoop *clientEventLoop } // connectMethodKey is the map key version of connectMethod, with a