Skip to content

Latest commit

 

History

History
executable file
·
152 lines (122 loc) · 4.46 KB

04.3.md

File metadata and controls

executable file
·
152 lines (122 loc) · 4.46 KB

04.3 - TCP Proxy

Building a non-TLS-terminating TCP proxy is pretty easy. It's very similar to the TCP server we have already created.

We listen for TCP connections. After one is established, we create a new connection to the forwarding IP:port and relay all data between the two connections. Without logging, this can be done with one io.Copy per direction, i.e. io.Copy(connDest, connSrc) and io.Copy(connSrc, connDest), each running in its own goroutine. With logging we have to use multiple goroutines (as we have seen before).

Only forwardConnection is different. Instead of calling handleConnection we call forwardConnection in a new goroutine.

Inside, we create a TCP connection to the server and two channels. Then we handle each side of the connection like the echo TCP server.

// 04.3-01-tcp-proxy.go
package main

import (
    "flag"
    "fmt"
    "io"
    "net"
)

// Command-line configuration, filled in by the flags registered in init.
var (
    // bindIP and bindPort form the address the proxy listens on.
    // destIP and destPort form the address connections are forwarded to.
    bindIP, bindPort, destIP, destPort string
)

// init registers the command-line flags and their default values.
//
// Note that the destination defaults are identical to the bind defaults, so at
// least one of -destIP/-destPort (or -bindIP/-bindPort) has to be overridden;
// otherwise the proxy would forward every connection to its own listener.
func init() {
    flag.StringVar(&bindPort, "bindPort", "12345", "bind port")
    flag.StringVar(&bindIP, "bindIP", "127.0.0.1", "bind IP")
    flag.StringVar(&destPort, "destPort", "12345", "destination port")
    flag.StringVar(&destIP, "destIP", "127.0.0.1", "destination IP")
}

// readSocket reads data from conn as it becomes available and passes every
// chunk to channel c. It returns when the peer closes the connection (io.EOF)
// or when any other read error occurs. It neither closes conn nor c; that is
// left to the caller.
func readSocket(conn net.Conn, c chan<- []byte) {

    // Scratch buffer that is reused for every Read call
    buf := make([]byte, 2048)
    // Store remote IP:port for logging
    rAddr := conn.RemoteAddr().String()

    for {
        // Read from connection; n is the number of bytes read
        n, err := conn.Read(buf)

        // Handle the data before looking at the error: a Read may return
        // n > 0 together with an error and those bytes must not be dropped.
        if n > 0 {
            // The receiver may look at the chunk long after we have started
            // the next Read, so it must get its own copy. Sending buf[:n]
            // directly would let the next Read overwrite queued data.
            data := make([]byte, n)
            copy(data, buf[:n])

            // Print data read from socket (only the n bytes actually read)
            fmt.Printf("Received from %v: %s\n", rAddr, data)
            // Send data to channel
            c <- data
        }

        // If connection is closed from the other side, stop reading
        if err == io.EOF {
            fmt.Println("Connection closed from", rAddr)
            return
        }
        // For other errors, print the error and return
        if err != nil {
            fmt.Println("Error reading from socket", err)
            return
        }
    }
}

// writeSocket reads data from channel and writes it to socket
func writeSocket(conn net.Conn, c <-chan []byte) {

    // Create a buffer to hold data
    buf := make([]byte, 2048)
    // Store remote IP:port for logging
    rAddr := conn.RemoteAddr().String()

    for {
        // Read from channel and copy to buffer
        buf = <-c
        // Write buffer
        n, err := conn.Write(buf)
        // If connection is closed from the other side
        if err == io.EOF {
            // Close the connction and return
            fmt.Println("Connection closed from", rAddr)
            return
        }
        // For other errors, print the error and return
        if err != nil {
            fmt.Println("Error writing to socket", err)
            return
        }
        // Log data sent
        fmt.Printf("Sent to %v: %s\n", rAddr, buf[:n])
    }
}

// forwardConnection creates a connection to the server and then passes packets
func forwardConnection(clientConn net.Conn) {

    // Converting host and port to destIP:destPort
    t := net.JoinHostPort(destIP, destPort)

    // Create a connection to server
    serverConn, err := net.Dial("tcp", t)
    if err != nil {
        fmt.Println(err)
        clientConn.Close()
        return
    }

    // Client to server channel
    c2s := make(chan []byte, 2048)
    // Server to client channel
    s2c := make(chan []byte, 2048)

    go readSocket(clientConn, c2s)
    go writeSocket(serverConn, c2s)
    go readSocket(serverConn, s2c)
    go writeSocket(clientConn, s2c)

}
// main parses the command-line flags, listens on bindIP:bindPort and forwards
// every accepted connection to destIP:destPort in its own goroutine.
func main() {

    flag.Parse()

    // Converting host and port to bindIP:bindPort
    t := net.JoinHostPort(bindIP, bindPort)

    // Refuse to forward to our own listening address: every accepted
    // connection would dial the proxy again, which accepts and dials again,
    // until the process runs out of file descriptors. The flag defaults are
    // exactly this configuration. This is a literal comparison only; it does
    // not catch aliases of the same address (e.g. a hostname vs. its IP).
    if dest := net.JoinHostPort(destIP, destPort); dest == t {
        panic(fmt.Sprintf("destination %v is the same as the listening address; use -destIP/-destPort to set a different destination", dest))
    }

    // Listen for connections on BindIP:BindPort
    ln, err := net.Listen("tcp", t)
    if err != nil {
        // If we cannot bind, print the error and quit
        panic(err)
    }

    fmt.Printf("Started listening on %v\n", t)

    // Wait for connections
    for {
        // Accept a connection
        conn, err := ln.Accept()
        if err != nil {
            // If there was an error print it and go back to listening
            fmt.Println(err)

            continue
        }
        fmt.Printf("Received connection from %v\n", conn.RemoteAddr().String())

        // Each connection is proxied independently of the accept loop
        go forwardConnection(conn)
    }
}