2
0
mirror of https://github.com/soheilhy/cmux.git synced 2024-11-10 03:31:52 +08:00
cmux/protocol.go
2023-02-18 10:45:29 +08:00

116 lines
2.3 KiB
Go

package cmux
import (
"bufio"
"bytes"
"fmt"
"net"
"strconv"
"strings"
)
// defaultBufSize is the size, in bytes, of the scratch buffer that
// checkPrefix allocates for its initial read of a connection.
const defaultBufSize = 1024
// prefix is the byte sequence a connection must begin with for it to be
// treated as carrying a PROXY protocol header.
var prefix = []byte("PROXY ")

// prefixLen is the number of bytes in prefix.
var prefixLen = len(prefix)
// checkPrefix performs a single Read on the connection and checks whether the
// data starts with a text-format PROXY protocol header line:
//
//	PROXY <type> <src addr> <dst addr> <src port> <dst port>\r\n
//
// If the data does not start with "PROXY ", everything that was read is
// written to m.buf.buffer, doneSniffing is called and nil is returned.
//
// If a TCP4 or TCP6 header is present, its source and destination are stored
// in m.srcAddr and m.dstAddr; both are set together, only after the whole
// header has been validated. An UNKNOWN header is accepted without recording
// any address. In both cases, bytes that followed the header line in the same
// read are written to m.buf.buffer and doneSniffing is called.
//
// An error is returned when the read yields no data, when the data ends in the
// middle of the prefix or before the end of the header line, or when the
// header is malformed.
func (m *MuxConn) checkPrefix() error {
	// Valid TCP ports fit in 16 bits.
	const maxPort = 65535

	buf := make([]byte, defaultBufSize)
	n, err := m.Read(buf)
	if err != nil && n == 0 {
		// Nothing was read: report the real cause instead of the io.EOF that
		// peeking into an empty reader would produce below. When some data
		// was returned alongside the error, that data is processed first.
		return err
	}
	reader := bufio.NewReader(bytes.NewReader(buf[:n]))

	// Incrementally check each byte of the prefix
	for i := 1; i <= prefixLen; i++ {
		inp, err := reader.Peek(i)
		if err != nil {
			return err
		}
		// Check for a prefix mismatch, quit early
		if !bytes.Equal(inp, prefix[:i]) {
			m.buf.buffer.Write(buf[:n])
			m.doneSniffing()
			return nil
		}
	}

	// Read the header line
	headerLine, err := reader.ReadString('\n')
	if err != nil {
		return err
	}

	// Strip the line terminator. Trimming suffixes (rather than cutting two
	// bytes unconditionally) keeps the last field intact when the line ends
	// in a bare "\n".
	header := strings.TrimSuffix(strings.TrimSuffix(headerLine, "\n"), "\r")

	// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
	parts := strings.Split(header, " ")
	if len(parts) < 2 {
		return fmt.Errorf("invalid header line: %s", header)
	}

	// Verify the type is known
	switch parts[1] {
	case "UNKNOWN":
		// No addresses to record; fall through to the payload handling below
		// so data following the header is not lost.
	case "TCP4", "TCP6":
		if len(parts) != 6 {
			return fmt.Errorf("invalid header line: %s", header)
		}

		// Parse out the source address
		srcIP := net.ParseIP(parts[2])
		if srcIP == nil {
			return fmt.Errorf("invalid source ip: %s", parts[2])
		}
		srcPort, convErr := strconv.Atoi(parts[4])
		if convErr != nil || srcPort < 0 || srcPort > maxPort {
			return fmt.Errorf("invalid source port: %s", parts[4])
		}

		// Parse out the destination address
		dstIP := net.ParseIP(parts[3])
		if dstIP == nil {
			return fmt.Errorf("invalid destination ip: %s", parts[3])
		}
		dstPort, convErr := strconv.Atoi(parts[5])
		if convErr != nil || dstPort < 0 || dstPort > maxPort {
			return fmt.Errorf("invalid destination port: %s", parts[5])
		}

		// Commit only once the entire header is known to be valid, so a bad
		// destination never leaves a lone source address behind.
		m.srcAddr = &net.TCPAddr{IP: srcIP, Port: srcPort}
		m.dstAddr = &net.TCPAddr{IP: dstIP, Port: dstPort}
	default:
		return fmt.Errorf("unhandled address type: %s", parts[1])
	}

	// Anything after the header line in this read is connection payload.
	if n != len(headerLine) {
		m.buf.buffer.Write(buf[len(headerLine):n])
		m.doneSniffing()
	}
	return nil
}
// RemoteAddr returns m.srcAddr when it has been set, and otherwise the remote
// address reported by the wrapped Conn.
func (m *MuxConn) RemoteAddr() net.Addr {
	if m.srcAddr == nil {
		return m.Conn.RemoteAddr()
	}
	return m.srcAddr
}
// LocalAddr returns m.dstAddr when it has been set, and otherwise the local
// address reported by the wrapped Conn.
func (m *MuxConn) LocalAddr() net.Addr {
	if m.dstAddr == nil {
		return m.Conn.LocalAddr()
	}
	return m.dstAddr
}