2
0
mirror of https://github.com/soheilhy/cmux.git synced 2024-09-20 02:55:46 +08:00

support proxy protocol

This commit is contained in:
munding 2023-02-18 10:45:29 +08:00
parent 5ec6847320
commit 7303c90c48
2 changed files with 118 additions and 1 deletions

View File

@ -273,7 +273,9 @@ func (l muxListener) Accept() (net.Conn, error) {
// MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
type MuxConn struct {
net.Conn
buf bufferedReader
buf bufferedReader
dstAddr *net.TCPAddr
srcAddr *net.TCPAddr
}
func newMuxConn(c net.Conn) *MuxConn {

115
protocol.go Normal file
View File

@ -0,0 +1,115 @@
package cmux
import (
"bufio"
"bytes"
"fmt"
"net"
"strconv"
"strings"
)
const (
defaultBufSize = 1024
)
var (
// prefix is the string we look for at the start of a connection
// to check if this connection is using the proxy protocol
prefix = []byte("PROXY ")
prefixLen = len(prefix)
)
func (m *MuxConn) checkPrefix() error {
buf := make([]byte, defaultBufSize)
n, err := m.Read(buf)
reader := bufio.NewReader(bytes.NewReader(buf[:n]))
// Incrementally check each byte of the prefix
for i := 1; i <= prefixLen; i++ {
inp, err := reader.Peek(i)
if err != nil {
return err
}
// Check for a prefix mismatch, quit early
if !bytes.Equal(inp, prefix[:i]) {
m.buf.buffer.Write(buf[:n])
m.doneSniffing()
return nil
}
}
// Read the header line
headerLine, err := reader.ReadString('\n')
if err != nil {
return err
}
// Strip the carriage return and new line
header := headerLine[:len(headerLine)-2]
// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
parts := strings.Split(header, " ")
if len(parts) < 2 {
return fmt.Errorf("invalid header line: %s", header)
}
// Verify the type is known
switch parts[1] {
case "UNKNOWN":
return nil
case "TCP4":
case "TCP6":
default:
return fmt.Errorf("unhandled address type: %s", parts[1])
}
if len(parts) != 6 {
return fmt.Errorf("invalid header line: %s", header)
}
// Parse out the source address
ip := net.ParseIP(parts[2])
if ip == nil {
return fmt.Errorf("invalid source ip: %s", parts[2])
}
port, err := strconv.Atoi(parts[4])
if err != nil {
return fmt.Errorf("invalid source port: %s", parts[4])
}
m.srcAddr = &net.TCPAddr{IP: ip, Port: port}
// Parse out the destination address
ip = net.ParseIP(parts[3])
if ip == nil {
return fmt.Errorf("invalid destination ip: %s", parts[3])
}
port, err = strconv.Atoi(parts[5])
if err != nil {
return fmt.Errorf("invalid destination port: %s", parts[5])
}
m.dstAddr = &net.TCPAddr{IP: ip, Port: port}
if n != len(headerLine) {
m.buf.buffer.Write(buf[len(headerLine):n])
m.doneSniffing()
}
return nil
}
func (m *MuxConn) RemoteAddr() net.Addr {
if m.srcAddr != nil {
return m.srcAddr
}
return m.Conn.RemoteAddr()
}
func (m *MuxConn) LocalAddr() net.Addr {
if m.dstAddr != nil {
return m.dstAddr
}
return m.Conn.LocalAddr()
}