// Mirror of https://github.com/soheilhy/cmux.git
// (synced 2024-11-14 11:31:28 +08:00; 116 lines, 2.3 KiB, Go)
|
package cmux
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"bytes"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
)
|
||
|
|
||
|
// defaultBufSize is the number of bytes checkPrefix requests in its single
// initial read from the connection.
const defaultBufSize = 1024

// prefix is the byte signature that opens a PROXY protocol (v1) header line;
// a connection whose first bytes match it is treated as proxied.
// prefixLen caches len(prefix).
var (
	prefix    = []byte("PROXY ")
	prefixLen = len(prefix)
)
|
||
|
|
||
|
func (m *MuxConn) checkPrefix() error {
|
||
|
buf := make([]byte, defaultBufSize)
|
||
|
n, err := m.Read(buf)
|
||
|
|
||
|
reader := bufio.NewReader(bytes.NewReader(buf[:n]))
|
||
|
|
||
|
// Incrementally check each byte of the prefix
|
||
|
for i := 1; i <= prefixLen; i++ {
|
||
|
inp, err := reader.Peek(i)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Check for a prefix mismatch, quit early
|
||
|
if !bytes.Equal(inp, prefix[:i]) {
|
||
|
m.buf.buffer.Write(buf[:n])
|
||
|
m.doneSniffing()
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Read the header line
|
||
|
headerLine, err := reader.ReadString('\n')
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Strip the carriage return and new line
|
||
|
header := headerLine[:len(headerLine)-2]
|
||
|
|
||
|
// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
|
||
|
parts := strings.Split(header, " ")
|
||
|
if len(parts) < 2 {
|
||
|
return fmt.Errorf("invalid header line: %s", header)
|
||
|
}
|
||
|
|
||
|
// Verify the type is known
|
||
|
switch parts[1] {
|
||
|
case "UNKNOWN":
|
||
|
return nil
|
||
|
case "TCP4":
|
||
|
case "TCP6":
|
||
|
default:
|
||
|
return fmt.Errorf("unhandled address type: %s", parts[1])
|
||
|
}
|
||
|
|
||
|
if len(parts) != 6 {
|
||
|
return fmt.Errorf("invalid header line: %s", header)
|
||
|
}
|
||
|
|
||
|
// Parse out the source address
|
||
|
ip := net.ParseIP(parts[2])
|
||
|
if ip == nil {
|
||
|
return fmt.Errorf("invalid source ip: %s", parts[2])
|
||
|
}
|
||
|
port, err := strconv.Atoi(parts[4])
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("invalid source port: %s", parts[4])
|
||
|
}
|
||
|
m.srcAddr = &net.TCPAddr{IP: ip, Port: port}
|
||
|
|
||
|
// Parse out the destination address
|
||
|
ip = net.ParseIP(parts[3])
|
||
|
if ip == nil {
|
||
|
return fmt.Errorf("invalid destination ip: %s", parts[3])
|
||
|
}
|
||
|
port, err = strconv.Atoi(parts[5])
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("invalid destination port: %s", parts[5])
|
||
|
}
|
||
|
m.dstAddr = &net.TCPAddr{IP: ip, Port: port}
|
||
|
|
||
|
if n != len(headerLine) {
|
||
|
m.buf.buffer.Write(buf[len(headerLine):n])
|
||
|
m.doneSniffing()
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m *MuxConn) RemoteAddr() net.Addr {
|
||
|
if m.srcAddr != nil {
|
||
|
return m.srcAddr
|
||
|
}
|
||
|
return m.Conn.RemoteAddr()
|
||
|
}
|
||
|
|
||
|
func (m *MuxConn) LocalAddr() net.Addr {
|
||
|
if m.dstAddr != nil {
|
||
|
return m.dstAddr
|
||
|
}
|
||
|
return m.Conn.LocalAddr()
|
||
|
}
|