mirror of
https://github.com/soheilhy/cmux.git
synced 2024-11-10 11:41:52 +08:00
support proxy protocol
This commit is contained in:
parent
5ec6847320
commit
7303c90c48
2
cmux.go
2
cmux.go
@ -274,6 +274,8 @@ func (l muxListener) Accept() (net.Conn, error) {
|
|||||||
type MuxConn struct {
|
type MuxConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
buf bufferedReader
|
buf bufferedReader
|
||||||
|
dstAddr *net.TCPAddr
|
||||||
|
srcAddr *net.TCPAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMuxConn(c net.Conn) *MuxConn {
|
func newMuxConn(c net.Conn) *MuxConn {
|
||||||
|
115
protocol.go
Normal file
115
protocol.go
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
package cmux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultBufSize = 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// prefix is the string we look for at the start of a connection
|
||||||
|
// to check if this connection is using the proxy protocol
|
||||||
|
prefix = []byte("PROXY ")
|
||||||
|
prefixLen = len(prefix)
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m *MuxConn) checkPrefix() error {
|
||||||
|
buf := make([]byte, defaultBufSize)
|
||||||
|
n, err := m.Read(buf)
|
||||||
|
|
||||||
|
reader := bufio.NewReader(bytes.NewReader(buf[:n]))
|
||||||
|
|
||||||
|
// Incrementally check each byte of the prefix
|
||||||
|
for i := 1; i <= prefixLen; i++ {
|
||||||
|
inp, err := reader.Peek(i)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for a prefix mismatch, quit early
|
||||||
|
if !bytes.Equal(inp, prefix[:i]) {
|
||||||
|
m.buf.buffer.Write(buf[:n])
|
||||||
|
m.doneSniffing()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the header line
|
||||||
|
headerLine, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip the carriage return and new line
|
||||||
|
header := headerLine[:len(headerLine)-2]
|
||||||
|
|
||||||
|
// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
|
||||||
|
parts := strings.Split(header, " ")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return fmt.Errorf("invalid header line: %s", header)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the type is known
|
||||||
|
switch parts[1] {
|
||||||
|
case "UNKNOWN":
|
||||||
|
return nil
|
||||||
|
case "TCP4":
|
||||||
|
case "TCP6":
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unhandled address type: %s", parts[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) != 6 {
|
||||||
|
return fmt.Errorf("invalid header line: %s", header)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse out the source address
|
||||||
|
ip := net.ParseIP(parts[2])
|
||||||
|
if ip == nil {
|
||||||
|
return fmt.Errorf("invalid source ip: %s", parts[2])
|
||||||
|
}
|
||||||
|
port, err := strconv.Atoi(parts[4])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid source port: %s", parts[4])
|
||||||
|
}
|
||||||
|
m.srcAddr = &net.TCPAddr{IP: ip, Port: port}
|
||||||
|
|
||||||
|
// Parse out the destination address
|
||||||
|
ip = net.ParseIP(parts[3])
|
||||||
|
if ip == nil {
|
||||||
|
return fmt.Errorf("invalid destination ip: %s", parts[3])
|
||||||
|
}
|
||||||
|
port, err = strconv.Atoi(parts[5])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination port: %s", parts[5])
|
||||||
|
}
|
||||||
|
m.dstAddr = &net.TCPAddr{IP: ip, Port: port}
|
||||||
|
|
||||||
|
if n != len(headerLine) {
|
||||||
|
m.buf.buffer.Write(buf[len(headerLine):n])
|
||||||
|
m.doneSniffing()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MuxConn) RemoteAddr() net.Addr {
|
||||||
|
if m.srcAddr != nil {
|
||||||
|
return m.srcAddr
|
||||||
|
}
|
||||||
|
return m.Conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MuxConn) LocalAddr() net.Addr {
|
||||||
|
if m.dstAddr != nil {
|
||||||
|
return m.dstAddr
|
||||||
|
}
|
||||||
|
return m.Conn.LocalAddr()
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user