mirror of
https://github.com/soheilhy/cmux.git
synced 2025-09-17 12:10:08 +08:00
Initial commit
This commit is contained in:
197
cmux.go
Normal file
197
cmux.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package cmux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Matcher matches a connection based on its content.
|
||||
type Matcher func(r io.Reader) (ok bool)
|
||||
|
||||
// ErrorHandler handles an error and returns whether
|
||||
// the mux should continue serving the listener.
|
||||
type ErrorHandler func(err error) (ok bool)
|
||||
|
||||
// ErrNotMatched is returned whenever a connection is not matched by any of
|
||||
// the matchers registered in the multiplexer.
|
||||
type ErrNotMatched struct {
|
||||
c net.Conn
|
||||
}
|
||||
|
||||
func (e ErrNotMatched) Error() string {
|
||||
return fmt.Sprintf("mux: connection %v not matched by an matcher",
|
||||
e.c.RemoteAddr())
|
||||
}
|
||||
|
||||
func (e ErrNotMatched) Temporary() bool { return true }
|
||||
func (e ErrNotMatched) Timeout() bool { return false }
|
||||
|
||||
type errListenerClosed string
|
||||
|
||||
func (e errListenerClosed) Error() string { return string(e) }
|
||||
func (e errListenerClosed) Temporary() bool { return false }
|
||||
func (e errListenerClosed) Timeout() bool { return false }
|
||||
|
||||
var (
|
||||
ErrListenerClosed = errListenerClosed("mux: listener closed")
|
||||
)
|
||||
|
||||
// New instantiates a new connection multiplexer.
|
||||
func New(l net.Listener) CMux {
|
||||
if !flag.Parsed() {
|
||||
flag.Parse()
|
||||
}
|
||||
|
||||
return &cMux{
|
||||
root: l,
|
||||
bufLen: 1024,
|
||||
errh: func(err error) bool { return true },
|
||||
}
|
||||
}
|
||||
|
||||
// CMux is a multiplexer for network connections.
|
||||
type CMux interface {
|
||||
// Match returns a net.Listener that sees (i.e., accepts) only
|
||||
// the connections matched by at least one of the matcher.
|
||||
//
|
||||
// The order used to call Match determines the priority of matchers.
|
||||
Match(matchers ...Matcher) net.Listener
|
||||
// Serve starts multiplexing the listener. Serve blocks and perhaps
|
||||
// should be invoked concurrently within a go routine.
|
||||
Serve() error
|
||||
// HandleError registers an error handler that handles listener errors.
|
||||
HandleError(h ErrorHandler)
|
||||
}
|
||||
|
||||
type matchersListener struct {
|
||||
ss []Matcher
|
||||
l muxListener
|
||||
}
|
||||
|
||||
type cMux struct {
|
||||
root net.Listener
|
||||
bufLen int
|
||||
errh ErrorHandler
|
||||
sls []matchersListener
|
||||
}
|
||||
|
||||
func (m *cMux) Match(matchers ...Matcher) (l net.Listener) {
|
||||
ml := muxListener{
|
||||
Listener: m.root,
|
||||
cch: make(chan net.Conn, m.bufLen),
|
||||
}
|
||||
m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
|
||||
return ml
|
||||
}
|
||||
|
||||
func (m *cMux) Serve() error {
|
||||
defer func() {
|
||||
for _, sl := range m.sls {
|
||||
close(sl.l.cch)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
c, err := m.root.Accept()
|
||||
if err != nil {
|
||||
if !m.handleErr(err) {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
muc := newMuxConn(c)
|
||||
matched := false
|
||||
outer:
|
||||
for _, sl := range m.sls {
|
||||
for _, s := range sl.ss {
|
||||
matched = s(muc.sniffer())
|
||||
muc.reset()
|
||||
if matched {
|
||||
sl.l.cch <- muc
|
||||
break outer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !matched {
|
||||
c.Close()
|
||||
err := ErrNotMatched{c: c}
|
||||
if !m.handleErr(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cMux) HandleError(h ErrorHandler) {
|
||||
m.errh = h
|
||||
}
|
||||
|
||||
func (m *cMux) handleErr(err error) bool {
|
||||
if !m.errh(err) {
|
||||
return false
|
||||
}
|
||||
|
||||
if ne, ok := err.(net.Error); ok {
|
||||
return ne.Temporary()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type muxListener struct {
|
||||
net.Listener
|
||||
cch chan net.Conn
|
||||
}
|
||||
|
||||
func (l muxListener) Accept() (c net.Conn, err error) {
|
||||
c, ok := <-l.cch
|
||||
if !ok {
|
||||
return nil, ErrListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
type MuxConn struct {
|
||||
net.Conn
|
||||
prv *bytes.Buffer
|
||||
nxt *bytes.Buffer
|
||||
}
|
||||
|
||||
func newMuxConn(c net.Conn) *MuxConn {
|
||||
return &MuxConn{
|
||||
Conn: c,
|
||||
prv: &bytes.Buffer{},
|
||||
nxt: &bytes.Buffer{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MuxConn) Read(b []byte) (n int, err error) {
|
||||
if n, err = m.prv.Read(b); err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
n, err = m.Conn.Read(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MuxConn) sniffer() io.Reader {
|
||||
return io.MultiReader(io.TeeReader(m.prv, m.nxt), io.TeeReader(m.Conn, m.nxt))
|
||||
}
|
||||
|
||||
func (m *MuxConn) reset() {
|
||||
if m.nxt.Len() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if m.prv.Len() != 0 {
|
||||
io.Copy(m.nxt, m.prv)
|
||||
}
|
||||
|
||||
m.prv, m.nxt = m.nxt, m.prv
|
||||
m.nxt.Reset()
|
||||
}
|
Reference in New Issue
Block a user