2
0
mirror of https://github.com/soheilhy/cmux.git synced 2024-11-14 11:31:28 +08:00

feat: mux conn support tls conn method

This commit is contained in:
luke 2022-09-28 01:27:53 +08:00
parent 5ec6847320
commit 177106da34

52
cmux.go
View File

@ -15,6 +15,8 @@
package cmux package cmux
import ( import (
"context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -189,7 +191,7 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
} }
for _, sl := range m.sls { for _, sl := range m.sls {
for _, s := range sl.ss { for _, s := range sl.ss {
matched := s(muc.Conn, muc.startSniffing()) matched := s(muc.getConn(), muc.startSniffing())
if matched { if matched {
muc.doneSniffing() muc.doneSniffing()
if m.readTimeout > noTimeout { if m.readTimeout > noTimeout {
@ -276,7 +278,13 @@ type MuxConn struct {
buf bufferedReader buf bufferedReader
} }
func newMuxConn(c net.Conn) *MuxConn { func newMuxConn(c net.Conn) muxConn {
if tlsconn, ok := c.(TLSConn); ok {
return &MuxTLSConn{
TLSConn: tlsconn,
buf: bufferedReader{source: c},
}
}
return &MuxConn{ return &MuxConn{
Conn: c, Conn: c,
buf: bufferedReader{source: c}, buf: bufferedReader{source: c},
@ -305,3 +313,43 @@ func (m *MuxConn) startSniffing() io.Reader {
func (m *MuxConn) doneSniffing() { func (m *MuxConn) doneSniffing() {
m.buf.reset(false) m.buf.reset(false)
} }
func (m *MuxConn) getConn() net.Conn { return m.Conn }
type muxConn interface {
net.Conn
startSniffing() io.Reader
doneSniffing()
getConn() net.Conn
}
type TLSConn interface {
net.Conn
CloseWrite() error
ConnectionState() tls.ConnectionState
Handshake() error
HandshakeContext(ctx context.Context) error
NetConn() net.Conn
OCSPResponse() []byte
VerifyHostname(host string) error
}
var _ TLSConn = (*tls.Conn)(nil)
type MuxTLSConn struct {
TLSConn
buf bufferedReader
}
func (m *MuxTLSConn) Read(p []byte) (int, error) { return m.buf.Read(p) }
func (m *MuxTLSConn) startSniffing() io.Reader {
m.buf.reset(true)
return &m.buf
}
func (m *MuxTLSConn) doneSniffing() {
m.buf.reset(false)
}
func (m *MuxTLSConn) getConn() net.Conn { return m.TLSConn }