diff --git a/cmux.go b/cmux.go index 5ba921e..79e491e 100644 --- a/cmux.go +++ b/cmux.go @@ -273,7 +273,9 @@ func (l muxListener) Accept() (net.Conn, error) { // MuxConn wraps a net.Conn and provides transparent sniffing of connection data. type MuxConn struct { net.Conn - buf bufferedReader + buf bufferedReader + dstAddr *net.TCPAddr + srcAddr *net.TCPAddr } func newMuxConn(c net.Conn) *MuxConn { diff --git a/protocol.go b/protocol.go new file mode 100644 index 0000000..1c4e835 --- /dev/null +++ b/protocol.go @@ -0,0 +1,115 @@ +package cmux + +import ( + "bufio" + "bytes" + "fmt" + "net" + "strconv" + "strings" +) + +const ( + defaultBufSize = 1024 +) + +var ( + // prefix is the string we look for at the start of a connection + // to check if this connection is using the proxy protocol + prefix = []byte("PROXY ") + prefixLen = len(prefix) +) + +func (m *MuxConn) checkPrefix() error { + buf := make([]byte, defaultBufSize) + n, err := m.Read(buf) + + reader := bufio.NewReader(bytes.NewReader(buf[:n])) + + // Incrementally check each byte of the prefix + for i := 1; i <= prefixLen; i++ { + inp, err := reader.Peek(i) + if err != nil { + return err + } + + // Check for a prefix mismatch, quit early + if !bytes.Equal(inp, prefix[:i]) { + m.buf.buffer.Write(buf[:n]) + m.doneSniffing() + return nil + } + } + + // Read the header line + headerLine, err := reader.ReadString('\n') + if err != nil { + return err + } + + // Strip the carriage return and new line + header := headerLine[:len(headerLine)-2] + + // Split on spaces, should be (PROXY ) + parts := strings.Split(header, " ") + if len(parts) < 2 { + return fmt.Errorf("invalid header line: %s", header) + } + + // Verify the type is known + switch parts[1] { + case "UNKNOWN": + return nil + case "TCP4": + case "TCP6": + default: + return fmt.Errorf("unhandled address type: %s", parts[1]) + } + + if len(parts) != 6 { + return fmt.Errorf("invalid header line: %s", header) + } + + // Parse out the source address + ip := net.ParseIP(parts[2]) + if ip == nil { + return fmt.Errorf("invalid source ip: %s", parts[2]) + } + port, err := strconv.Atoi(parts[4]) + if err != nil { + return fmt.Errorf("invalid source port: %s", parts[4]) + } + m.srcAddr = &net.TCPAddr{IP: ip, Port: port} + + // Parse out the destination address + ip = net.ParseIP(parts[3]) + if ip == nil { + return fmt.Errorf("invalid destination ip: %s", parts[3]) + } + port, err = strconv.Atoi(parts[5]) + if err != nil { + return fmt.Errorf("invalid destination port: %s", parts[5]) + } + m.dstAddr = &net.TCPAddr{IP: ip, Port: port} + + if n != len(headerLine) { + m.buf.buffer.Write(buf[len(headerLine):n]) + m.doneSniffing() + } + + return nil +} + +func (m *MuxConn) RemoteAddr() net.Addr { + if m.srcAddr != nil { + return m.srcAddr + } + return m.Conn.RemoteAddr() +} + +func (m *MuxConn) LocalAddr() net.Addr { + if m.dstAddr != nil { + return m.dstAddr + } + return m.Conn.LocalAddr() +}