diff --git a/common/ciphers/ssaead/cipher_packet.go b/common/ciphers/ssaead/cipher_packet.go index f9474d2..d20cd13 100644 --- a/common/ciphers/ssaead/cipher_packet.go +++ b/common/ciphers/ssaead/cipher_packet.go @@ -38,7 +38,7 @@ func GetAEADPacketCiphers(method string) func(string, net.PacketConn) (net.Packe PacketConn: packCon, IAEADCipher: c, key: evpBytesToKey(password, c.KeySize()), - buf: pool.GetUdpBuf(), + buf: pool.GetBuf(), } return ap, nil } diff --git a/common/ciphers/ssstream/cipher_packet.go b/common/ciphers/ssstream/cipher_packet.go index d841fbd..26e3423 100644 --- a/common/ciphers/ssstream/cipher_packet.go +++ b/common/ciphers/ssstream/cipher_packet.go @@ -86,13 +86,13 @@ func (c *streamPacket) ReadFrom(b []byte) (n int, addr net.Addr, err error) { return n, addr, err } - pool.GetUdpBuf() + pool.GetBuf() decryptr.XORKeyStream(b[ivLen:], b[ivLen:ivLen+n]) copy(b, b[ivLen:]) return n - ivLen, addr, err } func (c *streamPacket) Close() error { - pool.PutUdpBuf(c.buf) + pool.PutBuf(c.buf) return c.PacketConn.Close() } diff --git a/common/pool/pool.go b/common/pool/pool.go index e9c2b81..7ca5430 100644 --- a/common/pool/pool.go +++ b/common/pool/pool.go @@ -2,7 +2,7 @@ package pool import "sync" -const BufferSize = 4108 +const BufferSize = 4096 var ( poolMap map[int]*sync.Pool diff --git a/common/pool/udp_buffer_pool.go b/common/pool/udp_buffer_pool.go deleted file mode 100644 index 2bb3de8..0000000 --- a/common/pool/udp_buffer_pool.go +++ /dev/null @@ -1,25 +0,0 @@ -package pool - -import "sync" - -var udpPool *sync.Pool - -const MAX_UDP_BUF_SIZE int = 65507 - -func init() { - udpPool = &sync.Pool{ - New: func() interface{} { - return make([]byte, MAX_UDP_BUF_SIZE) - }, - } -} - -func GetUdpBuf() []byte { - buf := udpPool.Get().([]byte) - buf = buf[:cap(buf)] - return buf -} - -func PutUdpBuf(buf []byte) { - udpPool.Put(buf) -} diff --git a/proxy/server/shadowsocks.go b/proxy/server/shadowsocks.go index 8fcfe58..ec0c1a1 100644 --- a/proxy/server/shadowsocks.go +++ b/proxy/server/shadowsocks.go @@ -25,12 +25,6 @@ import ( type mode int -const ( - remoteServer mode = iota - relayClient - socksClient -) - var ( logging *log.Logging // ShadowsocksServerList Global map for store shadowsocks proxy @@ -268,8 +262,18 @@ func relayTCP(left, right net.Conn) (int64, int64, error) { Err error } ch := make(chan res) + defer func() { + if e := recover(); e != nil { + log.Error("panic in timedCopy: %v", e) + } + }() go func() { + defer func() { + if e := recover(); e != nil { + log.Error("panic in timedCopy: %v", e) + } + }() n, err := io.Copy(right, left) right.SetDeadline(time.Now()) // wake up the other goroutine blocking on right left.SetDeadline(time.Now()) // wake up the other goroutine blocking on left @@ -332,8 +336,8 @@ func (s *ShadowsocksProxy) startUDP() error { } nm := newNATmap(s.ConnectTimeout) - buf := pool.GetUdpBuf() - defer pool.PutUdpBuf(buf) + buf := pool.GetBuf() + defer pool.PutBuf(buf) // logging.Info("listening UDP on %s", addr) @@ -384,7 +388,7 @@ func (s *ShadowsocksProxy) startUDP() error { continue } - nm.Add(raddr, server, pc, remoteServer) + nm.Add(raddr, server, pc) } _, err = pc.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature @@ -457,7 +461,6 @@ func (m *natmap) Get(key string) net.PacketConn { func (m *natmap) Set(key string, pc net.PacketConn) { m.Lock() defer m.Unlock() - m.m[key] = pc } @@ -473,11 +476,10 @@ func (m *natmap) Del(key string) net.PacketConn { return nil } -func (m *natmap) Add(peer net.Addr, dst, src net.PacketConn, role mode) { +func (m *natmap) Add(peer net.Addr, dst, src net.PacketConn) { m.Set(peer.String(), src) - go func() { - timedCopy(dst, peer, src, m.timeout, role) + timedCopy(dst, peer, src, m.timeout) if pc := m.Del(peer.String()); pc != nil { pc.Close() } @@ -485,9 +487,14 @@ func (m *natmap) Add(peer net.Addr, dst, src net.PacketConn, role mode) { } // copy from src to dst at target with read timeout -func timedCopy(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout time.Duration, role mode) error { - buf := pool.GetUdpBuf() - defer pool.PutUdpBuf(buf) +func timedCopy(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout time.Duration) error { + buf := pool.GetBuf() + defer pool.PutBuf(buf) + defer func() { + if e := recover(); e != nil { + log.Error("panic in timedCopy: %v", e) + } + }() for { src.SetReadDeadline(time.Now().Add(timeout)) @@ -496,20 +503,11 @@ func timedCopy(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout return errors.Cause(err) } - switch role { - case remoteServer: // server -> client: add original packet source - srcAddr := socks.ParseAddr(raddr.String()) - srcAddrByte := srcAddr.Raw - copy(buf[len(srcAddrByte):], buf[:n]) - copy(buf, srcAddrByte) - _, err = dst.WriteTo(buf[:len(srcAddrByte)+n], target) - case relayClient: // client -> user: strip original packet source - srcAddr := socks.SplitAddr(buf[:n]) - srcAddrByte := srcAddr.Raw - _, err = dst.WriteTo(buf[len(srcAddrByte):n], target) - case socksClient: // client -> socks5 program: just set RSV and FRAG = 0 - _, err = dst.WriteTo(append([]byte{0, 0, 0}, buf[:n]...), target) - } + srcAddr := socks.ParseAddr(raddr.String()) + srcAddrByte := srcAddr.Raw + copy(buf[len(srcAddrByte):], buf[:n]) + copy(buf, srcAddrByte) + _, err = dst.WriteTo(buf[:len(srcAddrByte)+n], target) if err != nil { return errors.Cause(err)