mirror of
https://github.com/soheilhy/cmux.git
synced 2024-11-10 11:41:52 +08:00
support proxy protocol
This commit is contained in:
parent
5ec6847320
commit
7303c90c48
4
cmux.go
4
cmux.go
@ -273,7 +273,9 @@ func (l muxListener) Accept() (net.Conn, error) {
|
||||
// MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
|
||||
type MuxConn struct {
|
||||
net.Conn
|
||||
buf bufferedReader
|
||||
buf bufferedReader
|
||||
dstAddr *net.TCPAddr
|
||||
srcAddr *net.TCPAddr
|
||||
}
|
||||
|
||||
func newMuxConn(c net.Conn) *MuxConn {
|
||||
|
115
protocol.go
Normal file
115
protocol.go
Normal file
@ -0,0 +1,115 @@
|
||||
package cmux
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBufSize = 1024
|
||||
)
|
||||
|
||||
var (
|
||||
// prefix is the string we look for at the start of a connection
|
||||
// to check if this connection is using the proxy protocol
|
||||
prefix = []byte("PROXY ")
|
||||
prefixLen = len(prefix)
|
||||
)
|
||||
|
||||
func (m *MuxConn) checkPrefix() error {
|
||||
buf := make([]byte, defaultBufSize)
|
||||
n, err := m.Read(buf)
|
||||
|
||||
reader := bufio.NewReader(bytes.NewReader(buf[:n]))
|
||||
|
||||
// Incrementally check each byte of the prefix
|
||||
for i := 1; i <= prefixLen; i++ {
|
||||
inp, err := reader.Peek(i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for a prefix mismatch, quit early
|
||||
if !bytes.Equal(inp, prefix[:i]) {
|
||||
m.buf.buffer.Write(buf[:n])
|
||||
m.doneSniffing()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Read the header line
|
||||
headerLine, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Strip the carriage return and new line
|
||||
header := headerLine[:len(headerLine)-2]
|
||||
|
||||
// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
|
||||
parts := strings.Split(header, " ")
|
||||
if len(parts) < 2 {
|
||||
return fmt.Errorf("invalid header line: %s", header)
|
||||
}
|
||||
|
||||
// Verify the type is known
|
||||
switch parts[1] {
|
||||
case "UNKNOWN":
|
||||
return nil
|
||||
case "TCP4":
|
||||
case "TCP6":
|
||||
default:
|
||||
return fmt.Errorf("unhandled address type: %s", parts[1])
|
||||
}
|
||||
|
||||
if len(parts) != 6 {
|
||||
return fmt.Errorf("invalid header line: %s", header)
|
||||
}
|
||||
|
||||
// Parse out the source address
|
||||
ip := net.ParseIP(parts[2])
|
||||
if ip == nil {
|
||||
return fmt.Errorf("invalid source ip: %s", parts[2])
|
||||
}
|
||||
port, err := strconv.Atoi(parts[4])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid source port: %s", parts[4])
|
||||
}
|
||||
m.srcAddr = &net.TCPAddr{IP: ip, Port: port}
|
||||
|
||||
// Parse out the destination address
|
||||
ip = net.ParseIP(parts[3])
|
||||
if ip == nil {
|
||||
return fmt.Errorf("invalid destination ip: %s", parts[3])
|
||||
}
|
||||
port, err = strconv.Atoi(parts[5])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination port: %s", parts[5])
|
||||
}
|
||||
m.dstAddr = &net.TCPAddr{IP: ip, Port: port}
|
||||
|
||||
if n != len(headerLine) {
|
||||
m.buf.buffer.Write(buf[len(headerLine):n])
|
||||
m.doneSniffing()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MuxConn) RemoteAddr() net.Addr {
|
||||
if m.srcAddr != nil {
|
||||
return m.srcAddr
|
||||
}
|
||||
return m.Conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (m *MuxConn) LocalAddr() net.Addr {
|
||||
if m.dstAddr != nil {
|
||||
return m.dstAddr
|
||||
}
|
||||
return m.Conn.LocalAddr()
|
||||
}
|
Loading…
Reference in New Issue
Block a user