From 3f94d3f19e32f2f9179c3c0c0c4e241dfa8bce98 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Thu, 5 Sep 2024 18:07:50 +0800 Subject: [PATCH] refactor(x/net/http): Temporarily fix pipe write stuck Signed-off-by: hackerchai --- x/net/http/request.go | 32 +++++-- x/net/http/response.go | 35 ++++++- x/net/http/server.go | 198 ++++++++++++++++++++++++++-------------- x/net/http/servermux.go | 3 +- 4 files changed, 185 insertions(+), 83 deletions(-) diff --git a/x/net/http/request.go b/x/net/http/request.go index 9d458c3..26187ff 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -137,6 +137,12 @@ func (conn *conn) readRequest(hyperReq *hyper.Request) (*Request, error) { if body != nil { req.Body, conn.bodyWriter = io.Pipe() task := body.Foreach(getBodyChunk, c.Pointer(conn.bodyWriter), nil) + taskData := taskData { + body: nil, + conn: conn, + hyperTaskID: taskGetBody, + } + task.SetUserdata(c.Pointer(&taskData), nil) if task != nil { r := conn.executor.Push(task) if r != hyper.OK { @@ -175,6 +181,10 @@ func addHeader(data unsafe.Pointer, name *byte, nameLen uintptr, value *byte, va func getBodyChunk(userdata c.Pointer, chunk *hyper.Buf) c.Int { fmt.Printf("getBodyChunk called\n") writer := (*io.PipeWriter)(userdata) + if writer == nil { + fmt.Printf("writer is nil\n") + return hyper.IterBreak + } buf := chunk.Bytes() len := chunk.Len() bytes := unsafe.Slice(buf, len) @@ -182,13 +192,21 @@ func getBodyChunk(userdata c.Pointer, chunk *hyper.Buf) c.Int { fmt.Printf("Writing %d bytes to response body\n", len) fmt.Printf("Body chunk: %s\n", string(bytes)) - _, err := writer.Write(bytes) - fmt.Printf("Body chunk written\n") - if err != nil { - fmt.Println("Error writing to response body:", err) - writer.Close() - return hyper.IterBreak - } + go func() { + count, err := writer.Write(bytes) + fmt.Printf("Body chunk written: %d bytes\n", count) + if err != nil { + fmt.Println("Error writing to response body:", err) + writer.Close() + } + }() + // count, err := writer.Write(bytes) + // fmt.Printf("Body chunk written: %d bytes\n", count) + // if err != nil { + // fmt.Println("Error writing to response body:", err) + // writer.Close() + // return hyper.IterBreak + // } return hyper.IterContinue } diff --git a/x/net/http/response.go b/x/net/http/response.go index 20e6f01..29c065d 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -16,6 +16,7 @@ type response struct { body []byte channel *hyper.ResponseChannel resp *hyper.Response + request *Request } type body struct { @@ -24,13 +25,29 @@ type body struct { readLen uintptr } +type taskData struct { + body *body + conn *conn + hyperTaskID +} + +type hyperTaskID int + +const ( + taskSetBody hyperTaskID = iota + taskGetBody +) + + + var DefaultChunkSize uintptr = 8192 -func newResponse(channel *hyper.ResponseChannel) *response { +func newResponse(request *Request, channel *hyper.ResponseChannel) *response { fmt.Printf("newResponse called\n") resp := response{ header: make(Header), channel: channel, + request: request, } return &resp } @@ -90,6 +107,12 @@ func (r *response) WriteHeader(statusCode int) { func (r *response) finalize() error { fmt.Printf("finalize called\n") + err := r.request.Body.Close() + if err != nil { + return err + } + fmt.Printf("request body closed\n") + if !r.written { r.WriteHeader(200) } @@ -105,8 +128,13 @@ func (r *response) finalize() error { if body == nil { return fmt.Errorf("failed to create body") } + taskData := taskData{ + body: &bodyData, + conn: nil, + hyperTaskID: taskSetBody, + } body.SetDataFunc(setBodyDataFunc) - body.SetUserdata(unsafe.Pointer(&bodyData), nil) + body.SetUserdata(unsafe.Pointer(&taskData), nil) fmt.Printf("bodyData userdata set\n") fmt.Printf("bodyData set\n") @@ -124,12 +152,13 @@ func (r *response) finalize() error { func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { fmt.Printf("setBodyDataFunc called\n") - body := (*body)(userdata) + body := (*taskData)(userdata).body if body.len > 0 { //debug fmt.Println("<") fmt.Printf("%s", string(body.data)) + fmt.Println("") if body.len > DefaultChunkSize { *chunk = hyper.CopyBuf(&body.data[body.readLen], DefaultChunkSize) diff --git a/x/net/http/server.go b/x/net/http/server.go index d006a95..c6aee65 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -33,13 +33,11 @@ type Server struct { Addr string Handler Handler - uvLoop *libuv.Loop - uvServer libuv.Tcp - inShutdown atomic.Bool - http1Opts *hyper.Http1ServerconnOptions - http2Opts *hyper.Http2ServerconnOptions - checkHandle libuv.Check - idleHandle libuv.Idle + uvLoop *libuv.Loop + uvServer libuv.Tcp + inShutdown atomic.Bool + //checkHandle libuv.Check + idleHandle libuv.Idle mu sync.Mutex activeConnections map[*conn]struct{} @@ -51,6 +49,8 @@ type conn struct { eventMask c.Uint readWaker *hyper.Waker writeWaker *hyper.Waker + http1Opts *hyper.Http1ServerconnOptions + http2Opts *hyper.Http2ServerconnOptions isClosing atomic.Bool closedHandles int32 executor *hyper.Executor @@ -119,49 +119,43 @@ func (srv *Server) ListenAndServe() error { return fmt.Errorf("failed to set SO_REUSEADDR: %v", result) } - //(*libuv.Stream)(&srv.uvServer).Data = unsafe.Pointer(srv) - (*libuv.Handle)(unsafe.Pointer(&srv.uvServer)).SetData(unsafe.Pointer(srv)) + srv.uvServer.Data = unsafe.Pointer(srv) if err := (*libuv.Stream)(&srv.uvServer).Listen(128, onNewConnection); err != 0 { return fmt.Errorf("failed to listen: %v", err) } - if r := libuv.InitCheck(srv.uvLoop, &srv.checkHandle); r != 0 { - fmt.Fprintf(os.Stderr, "Failed to initialize check handler: %d\n", r) - os.Exit(1) - } - - (*libuv.Handle)(unsafe.Pointer(&srv.checkHandle)).SetData(unsafe.Pointer(srv)) - - if r := srv.checkHandle.Start(onCheck); r != 0 { - fmt.Fprintf(os.Stderr, "Failed to start check handler: %d\n", r) - os.Exit(1) - } - - // if r := libuv.InitIdle(srv.uvLoop, &srv.idleHandle); r != 0 { - // fmt.Fprintf(os.Stderr, "Failed to initialize idle handler: %d\n", r) + // if r := libuv.InitCheck(srv.uvLoop, &srv.checkHandle); r != 0 { + // fmt.Fprintf(os.Stderr, "Failed to initialize check handler: %d\n", r) // os.Exit(1) // } - // (*libuv.Handle)(unsafe.Pointer(&srv.idleHandle)).SetData(unsafe.Pointer(srv)) + // (*libuv.Handle)(unsafe.Pointer(&srv.checkHandle)).SetData(unsafe.Pointer(srv)) - // if r := srv.idleHandle.Start(onIdle); r != 0 { - // fmt.Fprintf(os.Stderr, "Failed to start idle handler: %d\n", r) + // if r := srv.checkHandle.Start(onCheck); r != 0 { + // fmt.Fprintf(os.Stderr, "Failed to start check handler: %d\n", r) // os.Exit(1) // } - fmt.Printf("Listening on %s\n", srv.Addr) + if r := libuv.InitIdle(srv.uvLoop, &srv.idleHandle); r != 0 { + fmt.Fprintf(os.Stderr, "Failed to initialize idle handler: %d\n", r) + os.Exit(1) + } - for { - res := srv.uvLoop.Run(libuv.RUN_NOWAIT) - if res < 0 { - fmt.Fprintf(os.Stderr, "uv_loop_run error: %s\n", libuv.Strerror(libuv.Errno(res))) - break - } + (*libuv.Handle)(unsafe.Pointer(&srv.idleHandle)).SetData(unsafe.Pointer(srv)) - if srv.shuttingDown() { - break - } + if r := srv.idleHandle.Start(onIdle); r != 0 { + fmt.Fprintf(os.Stderr, "Failed to start idle handler: %d\n", r) + os.Exit(1) } + + fmt.Printf("Listening on %s\n", srv.Addr) + + res := srv.uvLoop.Run(libuv.RUN_DEFAULT) + if res != 0 { + fmt.Fprintf(os.Stderr, "Error in event loop: %v\n", res) + os.Exit(1) + } + return nil } @@ -250,7 +244,6 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { io := createIo(conn) service := hyper.ServiceNew(serverCallback) service.SetUserdata(unsafe.Pointer(userdata), nil) - http1Opts := hyper.Http1ServerconnOptionsNew(conn.executor) if http1Opts == nil { fmt.Fprintf(os.Stderr, "Failed to create http1_opts\n") @@ -261,7 +254,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { fmt.Fprintf(os.Stderr, "Failed to set header read timeout for http1_opts\n") os.Exit(1) } - srv.http1Opts = http1Opts + conn.http1Opts = http1Opts http2Opts := hyper.Http2ServerconnOptionsNew(conn.executor) if http2Opts == nil { @@ -278,7 +271,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { fmt.Fprintf(os.Stderr, "Failed to set keep alive timeout for http2_opts\n") os.Exit(1) } - srv.http2Opts = http2Opts + conn.http2Opts = http2Opts serverconn := hyper.ServeHttpXConnection(http1Opts, http2Opts, io, service) conn.executor.Push(serverconn) @@ -289,29 +282,30 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { } } -func onCheck(handle *libuv.Check) { - //fmt.Println("onCheck called") - srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) - for conn := range srv.activeConnections { - if conn.executor != nil { - task := conn.executor.Poll() - for task != nil { - srv.handleTask(task) - task = conn.executor.Poll() - } - } - } - - if srv.shuttingDown() { - fmt.Println("Shutdown initiated, cleaning up...") - handle.Stop() - } -} +// func onCheck(handle *libuv.Check) { +// //fmt.Println("onCheck called") +// srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) +// for conn := range srv.activeConnections { +// if conn.executor != nil { +// task := conn.executor.Poll() +// for task != nil { +// srv.handleTask(task) +// task = conn.executor.Poll() +// } +// } +// } + +// if srv.shuttingDown() { +// fmt.Println("Shutdown initiated, cleaning up...") +// handle.Stop() +// } +// } func onIdle(handle *libuv.Idle) { //fmt.Println("onIdle called") srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) for conn := range srv.activeConnections { + //fmt.Println("onIdle conn called") if conn.executor != nil { task := conn.executor.Poll() for task != nil { @@ -341,16 +335,35 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h return } - res := newResponse(channel) + res := newResponse(req, channel) fmt.Printf("Response created\n") - userData.server.Handler.ServeHTTP(res, req) + go func() { + userData.server.Handler.ServeHTTP(res, req) + res.finalize() + }() - res.finalize() + // userData.server.Handler.ServeHTTP(res, req) + + // res.finalize() } func (srv *Server) handleTask(task *hyper.Task) { taskType := task.Type() + taskData := (*taskData)(task.Userdata()) + fmt.Println("handleTask called") + if taskData != nil { + if taskData.hyperTaskID == taskGetBody { + fmt.Println("taskGetBody called") + if taskData.conn != nil && taskData.conn.bodyWriter != nil { + fmt.Println("taskGetBody calling Close") + taskData.conn.bodyWriter.Close() + } + } else if taskData.hyperTaskID == taskSetBody { + fmt.Println("taskSetBody called") + } + } + if taskType == hyper.TaskError { fmt.Println("hyper task failed with error!") @@ -527,6 +540,20 @@ func freeConnData(userdata c.Pointer) { conn.writeWaker = nil } + if conn.executor != nil { + conn.executor.Free() + conn.executor = nil + } + + if conn.http1Opts != nil { + conn.http1Opts.Free() + conn.http1Opts = nil + } + if conn.http2Opts != nil { + conn.http2Opts.Free() + conn.http2Opts = nil + } + if (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).IsClosing() == 0 { (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) } @@ -549,23 +576,16 @@ func (srv *Server) Close() error { defer srv.mu.Unlock() for c := range srv.activeConnections { - if c.executor != nil { - c.executor.Free() - } + c.Close() + delete(srv.activeConnections, c) } srv.uvLoop.Walk(closeWalkCb, nil) - srv.uvLoop.Run(libuv.RUN_DEFAULT) + srv.uvLoop.Run(libuv.RUN_ONCE) + (*libuv.Handle)(unsafe.Pointer(&srv.uvServer)).Close(nil) srv.uvLoop.Close() - - if srv.http1Opts != nil { - srv.http1Opts.Free() - } - if srv.http2Opts != nil { - srv.http2Opts.Free() - } return nil } @@ -573,6 +593,42 @@ func (s *Server) shuttingDown() bool { return s.inShutdown.Load() } +func (c *conn) shuttingDown() bool { + return c.isClosing.Load() +} + +func (c *conn) Close() { + c.isClosing.Store(true) + if c.shuttingDown() { + return + } + + if c.readWaker != nil { + c.readWaker.Free() + c.readWaker = nil + } + if c.writeWaker != nil { + c.writeWaker.Free() + c.writeWaker = nil + } + + if c.executor != nil { + c.executor.Free() + c.executor = nil + } + if c.http1Opts != nil { + c.http1Opts.Free() + c.http1Opts = nil + } + if c.http2Opts != nil { + c.http2Opts.Free() + c.http2Opts = nil + } + + (*libuv.Handle)(unsafe.Pointer(&c.pollHandle)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&c.stream)).Close(nil) +} + type HandlerFunc func(ResponseWriter, *Request) func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { diff --git a/x/net/http/servermux.go b/x/net/http/servermux.go index 21d9b20..6da8bce 100644 --- a/x/net/http/servermux.go +++ b/x/net/http/servermux.go @@ -41,7 +41,6 @@ func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Re } func (mux *ServeMux) Handle(pattern string, handler Handler) { - fmt.Printf("Handle called with pattern: %s\n", pattern) mux.mu.Lock() defer mux.mu.Unlock() @@ -56,4 +55,4 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) { } mux.m[pattern] = muxEntry{h: handler, pattern: pattern} -} \ No newline at end of file +}