diff --git a/cmux.go b/cmux.go index 5ba921e..af2173f 100644 --- a/cmux.go +++ b/cmux.go @@ -15,6 +15,8 @@ package cmux import ( + "context" + "crypto/tls" "errors" "fmt" "io" @@ -189,7 +191,7 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { } for _, sl := range m.sls { for _, s := range sl.ss { - matched := s(muc.Conn, muc.startSniffing()) + matched := s(muc.getConn(), muc.startSniffing()) if matched { muc.doneSniffing() if m.readTimeout > noTimeout { @@ -276,7 +278,13 @@ type MuxConn struct { buf bufferedReader } -func newMuxConn(c net.Conn) *MuxConn { +func newMuxConn(c net.Conn) muxConn { + if tlsconn, ok := c.(TLSConn); ok { + return &MuxTLSConn{ + TLSConn: tlsconn, + buf: bufferedReader{source: c}, + } + } return &MuxConn{ Conn: c, buf: bufferedReader{source: c}, @@ -305,3 +313,43 @@ func (m *MuxConn) startSniffing() io.Reader { func (m *MuxConn) doneSniffing() { m.buf.reset(false) } + +func (m *MuxConn) getConn() net.Conn { return m.Conn } + +type muxConn interface { + net.Conn + startSniffing() io.Reader + doneSniffing() + getConn() net.Conn +} + +type TLSConn interface { + net.Conn + CloseWrite() error + ConnectionState() tls.ConnectionState + Handshake() error + HandshakeContext(ctx context.Context) error + NetConn() net.Conn + OCSPResponse() []byte + VerifyHostname(host string) error +} + +var _ TLSConn = (*tls.Conn)(nil) + +type MuxTLSConn struct { + TLSConn + buf bufferedReader +} + +func (m *MuxTLSConn) Read(p []byte) (int, error) { return m.buf.Read(p) } + +func (m *MuxTLSConn) startSniffing() io.Reader { + m.buf.reset(true) + return &m.buf +} + +func (m *MuxTLSConn) doneSniffing() { + m.buf.reset(false) +} + +func (m *MuxTLSConn) getConn() net.Conn { return m.TLSConn }