mirror of
https://github.com/soheilhy/cmux.git
synced 2025-10-17 04:43:12 +08:00
Compare commits
25 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
711042c095 | ||
|
ac00452023 | ||
|
bb79a83465 | ||
|
7e08502c7a | ||
|
0c129dc694 | ||
|
34a8ab6cda | ||
|
3b204bab2a | ||
|
9a3402ad7a | ||
|
4f90533583 | ||
|
8cd60510aa | ||
|
f671b41193 | ||
|
885b8d8a14 | ||
|
0068a46c9c | ||
|
6a5d332559 | ||
|
c0f3570a02 | ||
|
b6ec57c1a4 | ||
|
79b9df6ccf | ||
|
210139db95 | ||
|
bf4a8ede9e | ||
|
f661dcfb59 | ||
|
526b64db7a | ||
|
3ac8d3a667 | ||
|
861c99e0fc | ||
|
13f520d62c | ||
|
e132036cce |
@@ -1,8 +1,9 @@
|
||||
language: go
|
||||
|
||||
go:
|
||||
- 1.5
|
||||
- 1.6
|
||||
- 1.7
|
||||
- 1.8
|
||||
- tip
|
||||
|
||||
matrix:
|
||||
|
@@ -3,6 +3,7 @@
|
||||
# Auto-generated with:
|
||||
# git log --oneline --pretty=format:'%an <%aE>' | sort -u
|
||||
#
|
||||
Andreas Jaekle <andreas@jaekle.net>
|
||||
Dmitri Shuralyov <shurcooL@gmail.com>
|
||||
Ethan Mosbaugh <emosbaugh@gmail.com>
|
||||
Soheil Hassas Yeganeh <soheil.h.y@gmail.com>
|
||||
|
@@ -32,7 +32,7 @@ httpS := &http.Server{
|
||||
}
|
||||
|
||||
trpcS := rpc.NewServer()
|
||||
s.Register(&ExampleRPCRcvr{})
|
||||
trpcS.Register(&ExampleRPCRcvr{})
|
||||
|
||||
// Use the muxed listeners for your servers.
|
||||
go grpcS.Serve(grpcL)
|
||||
|
@@ -20,6 +20,7 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
@@ -43,6 +44,10 @@ func (c *mockConn) Read(b []byte) (n int, err error) {
|
||||
return c.r.Read(b)
|
||||
}
|
||||
|
||||
func (c *mockConn) SetReadDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func discard(l net.Listener) {
|
||||
for {
|
||||
if _, err := l.Accept(); err != nil {
|
||||
|
@@ -42,6 +42,10 @@ func (s *bufferedReader) Read(p []byte) (int, error) {
|
||||
bn := copy(p, s.buffer.Bytes()[s.bufferRead:s.bufferSize])
|
||||
s.bufferRead += bn
|
||||
return bn, s.lastErr
|
||||
} else if !s.sniffing && s.buffer.Cap() != 0 {
|
||||
// We don't need the buffer anymore.
|
||||
// Reset it to release the internal slice.
|
||||
s.buffer = bytes.Buffer{}
|
||||
}
|
||||
|
||||
// If there is nothing more to return in the sniffed buffer, read from the
|
||||
|
36
cmux.go
36
cmux.go
@@ -19,6 +19,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Matcher matches a connection based on its content.
|
||||
@@ -60,13 +61,17 @@ func (e errListenerClosed) Timeout() bool { return false }
|
||||
// listener is closed.
|
||||
var ErrListenerClosed = errListenerClosed("mux: listener closed")
|
||||
|
||||
// for readability of readTimeout
|
||||
var noTimeout time.Duration
|
||||
|
||||
// New instantiates a new connection multiplexer.
|
||||
func New(l net.Listener) CMux {
|
||||
return &cMux{
|
||||
root: l,
|
||||
bufLen: 1024,
|
||||
errh: func(_ error) bool { return true },
|
||||
donec: make(chan struct{}),
|
||||
root: l,
|
||||
bufLen: 1024,
|
||||
errh: func(_ error) bool { return true },
|
||||
donec: make(chan struct{}),
|
||||
readTimeout: noTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +95,8 @@ type CMux interface {
|
||||
Serve() error
|
||||
// HandleError registers an error handler that handles listener errors.
|
||||
HandleError(ErrorHandler)
|
||||
// sets a timeout for the read of matchers
|
||||
SetReadTimeout(time.Duration)
|
||||
}
|
||||
|
||||
type matchersListener struct {
|
||||
@@ -98,11 +105,12 @@ type matchersListener struct {
|
||||
}
|
||||
|
||||
type cMux struct {
|
||||
root net.Listener
|
||||
bufLen int
|
||||
errh ErrorHandler
|
||||
donec chan struct{}
|
||||
sls []matchersListener
|
||||
root net.Listener
|
||||
bufLen int
|
||||
errh ErrorHandler
|
||||
donec chan struct{}
|
||||
sls []matchersListener
|
||||
readTimeout time.Duration
|
||||
}
|
||||
|
||||
func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
|
||||
@@ -129,6 +137,10 @@ func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener {
|
||||
return ml
|
||||
}
|
||||
|
||||
func (m *cMux) SetReadTimeout(t time.Duration) {
|
||||
m.readTimeout = t
|
||||
}
|
||||
|
||||
func (m *cMux) Serve() error {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
@@ -163,11 +175,17 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
muc := newMuxConn(c)
|
||||
if m.readTimeout > noTimeout {
|
||||
_ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
|
||||
}
|
||||
for _, sl := range m.sls {
|
||||
for _, s := range sl.ss {
|
||||
matched := s(muc.Conn, muc.startSniffing())
|
||||
if matched {
|
||||
muc.doneSniffing()
|
||||
if m.readTimeout > noTimeout {
|
||||
_ = c.SetReadDeadline(time.Time{})
|
||||
}
|
||||
select {
|
||||
case sl.l.connc <- muc:
|
||||
case <-donec:
|
||||
|
261
cmux_test.go
261
cmux_test.go
@@ -15,13 +15,20 @@
|
||||
package cmux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/build"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -31,6 +38,7 @@ import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -39,7 +47,7 @@ const (
|
||||
)
|
||||
|
||||
func safeServe(errCh chan<- error, muxl CMux) {
|
||||
if err := muxl.Serve(); !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
if err := muxl.Serve(); !strings.Contains(err.Error(), "use of closed") {
|
||||
errCh <- err
|
||||
}
|
||||
}
|
||||
@@ -73,14 +81,17 @@ func (l *chanListener) Accept() (net.Conn, error) {
|
||||
}
|
||||
|
||||
func testListener(t *testing.T) (net.Listener, func()) {
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var once sync.Once
|
||||
return l, func() {
|
||||
if err := l.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
once.Do(func() {
|
||||
if err := l.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,8 +133,57 @@ func runTestHTTPServer(errCh chan<- error, l net.Listener) {
|
||||
}
|
||||
}
|
||||
|
||||
func generateTLSCert(t *testing.T) {
|
||||
err := exec.Command("go", "run", build.Default.GOROOT+"/src/crypto/tls/generate_cert.go", "--host", "*").Run()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupTLSCert(t *testing.T) {
|
||||
err := os.Remove("cert.pem")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err = os.Remove("key.pem")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func runTestTLSServer(errCh chan<- error, l net.Listener) {
|
||||
certificate, err := tls.LoadX509KeyPair("cert.pem", "key.pem")
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
log.Printf("1")
|
||||
return
|
||||
}
|
||||
|
||||
config := &tls.Config{
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
Rand: rand.Reader,
|
||||
}
|
||||
|
||||
tlsl := tls.NewListener(l, config)
|
||||
runTestHTTPServer(errCh, tlsl)
|
||||
}
|
||||
|
||||
func runTestHTTP1Client(t *testing.T, addr net.Addr) {
|
||||
r, err := http.Get("http://" + addr.String())
|
||||
runTestHTTPClient(t, "http", addr)
|
||||
}
|
||||
|
||||
func runTestTLSClient(t *testing.T, addr net.Addr) {
|
||||
runTestHTTPClient(t, "https", addr)
|
||||
}
|
||||
|
||||
func runTestHTTPClient(t *testing.T, proto string, addr net.Addr) {
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
r, err := client.Get(proto + "://" + addr.String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -181,6 +241,84 @@ func runTestRPCClient(t *testing.T, addr net.Addr) {
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
handleHTTP1Close = 1
|
||||
handleHTTP1Request = 2
|
||||
handleAnyClose = 3
|
||||
handleAnyRequest = 4
|
||||
)
|
||||
|
||||
func TestTimeout(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
lis, Close := testListener(t)
|
||||
defer Close()
|
||||
result := make(chan int, 5)
|
||||
testDuration := time.Millisecond * 500
|
||||
m := New(lis)
|
||||
m.SetReadTimeout(testDuration)
|
||||
http1 := m.Match(HTTP1Fast())
|
||||
any := m.Match(Any())
|
||||
go func() {
|
||||
_ = m.Serve()
|
||||
}()
|
||||
go func() {
|
||||
con, err := http1.Accept()
|
||||
if err != nil {
|
||||
result <- handleHTTP1Close
|
||||
} else {
|
||||
_, _ = con.Write([]byte("http1"))
|
||||
_ = con.Close()
|
||||
result <- handleHTTP1Request
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
con, err := any.Accept()
|
||||
if err != nil {
|
||||
result <- handleAnyClose
|
||||
} else {
|
||||
_, _ = con.Write([]byte("any"))
|
||||
_ = con.Close()
|
||||
result <- handleAnyRequest
|
||||
}
|
||||
}()
|
||||
time.Sleep(testDuration) // wait to prevent timeouts on slow test-runners
|
||||
client, err := net.Dial("tcp", lis.Addr().String())
|
||||
if err != nil {
|
||||
log.Fatal("testTimeout client failed: ", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = client.Close()
|
||||
}()
|
||||
time.Sleep(testDuration / 2)
|
||||
if len(result) != 0 {
|
||||
log.Print("tcp ")
|
||||
t.Fatal("testTimeout failed: accepted to fast: ", len(result))
|
||||
}
|
||||
_ = client.SetReadDeadline(time.Now().Add(testDuration * 3))
|
||||
buffer := make([]byte, 10)
|
||||
rl, err := client.Read(buffer)
|
||||
if err != nil {
|
||||
t.Fatal("testTimeout failed: client error: ", err, rl)
|
||||
}
|
||||
Close()
|
||||
if rl != 3 {
|
||||
log.Print("testTimeout failed: response from wrong sevice ", rl)
|
||||
}
|
||||
if string(buffer[0:3]) != "any" {
|
||||
log.Print("testTimeout failed: response from wrong sevice ")
|
||||
}
|
||||
time.Sleep(testDuration * 2)
|
||||
if len(result) != 2 {
|
||||
t.Fatal("testTimeout failed: accepted to less: ", len(result))
|
||||
}
|
||||
if a := <-result; a != handleAnyRequest {
|
||||
t.Fatal("testTimeout failed: any rule did not match")
|
||||
}
|
||||
if a := <-result; a != handleHTTP1Close {
|
||||
t.Fatal("testTimeout failed: no close an http rule")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
errCh := make(chan error)
|
||||
@@ -261,6 +399,33 @@ func TestAny(t *testing.T) {
|
||||
runTestHTTP1Client(t, l.Addr())
|
||||
}
|
||||
|
||||
func TestTLS(t *testing.T) {
|
||||
generateTLSCert(t)
|
||||
defer cleanupTLSCert(t)
|
||||
defer leakCheck(t)()
|
||||
errCh := make(chan error)
|
||||
defer func() {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}()
|
||||
l, cleanup := testListener(t)
|
||||
defer cleanup()
|
||||
|
||||
muxl := New(l)
|
||||
tlsl := muxl.Match(TLS())
|
||||
httpl := muxl.Match(Any())
|
||||
|
||||
go runTestTLSServer(errCh, tlsl)
|
||||
go runTestHTTPServer(errCh, httpl)
|
||||
go safeServe(errCh, muxl)
|
||||
|
||||
runTestHTTP1Client(t, l.Addr())
|
||||
runTestTLSClient(t, l.Addr())
|
||||
}
|
||||
|
||||
func TestHTTP2(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
errCh := make(chan error)
|
||||
@@ -312,6 +477,85 @@ func TestHTTP2(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP2MatchHeaderField(t *testing.T) {
|
||||
testHTTP2MatchHeaderField(t, HTTP2HeaderField, "value", "value", "anothervalue")
|
||||
}
|
||||
|
||||
func TestHTTP2MatchHeaderFieldPrefix(t *testing.T) {
|
||||
testHTTP2MatchHeaderField(t, HTTP2HeaderFieldPrefix, "application/grpc+proto", "application/grpc", "application/json")
|
||||
}
|
||||
|
||||
func testHTTP2MatchHeaderField(
|
||||
t *testing.T,
|
||||
matcherConstructor func(string, string) Matcher,
|
||||
headerValue string,
|
||||
matchValue string,
|
||||
notMatchValue string,
|
||||
) {
|
||||
defer leakCheck(t)()
|
||||
errCh := make(chan error)
|
||||
defer func() {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}()
|
||||
name := "name"
|
||||
writer, reader := net.Pipe()
|
||||
go func() {
|
||||
if _, err := io.WriteString(writer, http2.ClientPreface); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
enc := hpack.NewEncoder(&buf)
|
||||
if err := enc.WriteField(hpack.HeaderField{Name: name, Value: headerValue}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
framer := http2.NewFramer(writer, nil)
|
||||
err := framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: 1,
|
||||
BlockFragment: buf.Bytes(),
|
||||
EndStream: true,
|
||||
EndHeaders: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
l := newChanListener()
|
||||
l.connCh <- reader
|
||||
muxl := New(l)
|
||||
// Register a bogus matcher that only reads one byte.
|
||||
muxl.Match(func(r io.Reader) bool {
|
||||
var b [1]byte
|
||||
_, _ = r.Read(b[:])
|
||||
return false
|
||||
})
|
||||
// Create a matcher that cannot match the response.
|
||||
muxl.Match(matcherConstructor(name, notMatchValue))
|
||||
// Then match with the expected field.
|
||||
h2l := muxl.Match(matcherConstructor(name, matchValue))
|
||||
go safeServe(errCh, muxl)
|
||||
muxedConn, err := h2l.Accept()
|
||||
close(l.connCh)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var b [len(http2.ClientPreface)]byte
|
||||
// We have the sniffed buffer first...
|
||||
if _, err := muxedConn.Read(b[:]); err == io.EOF {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(b[:]) != http2.ClientPreface {
|
||||
t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPGoRPC(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
errCh := make(chan error)
|
||||
@@ -415,7 +659,9 @@ func TestClose(t *testing.T) {
|
||||
if err != ErrListenerClosed {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := c2.Read([]byte{}); err != io.ErrClosedPipe {
|
||||
// The error is either io.ErrClosedPipe or net.OpError wrapping
|
||||
// a net.pipeError depending on the go version.
|
||||
if _, err := c2.Read([]byte{}); !strings.Contains(err.Error(), "closed") {
|
||||
t.Fatalf("connection is not closed and is leaked: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -439,6 +685,7 @@ func interestingGoroutines() (gs []string) {
|
||||
}
|
||||
|
||||
if stack == "" ||
|
||||
strings.Contains(stack, "main.main()") ||
|
||||
strings.Contains(stack, "testing.Main(") ||
|
||||
strings.Contains(stack, "runtime.goexit") ||
|
||||
strings.Contains(stack, "created by runtime.gc") ||
|
||||
|
@@ -112,7 +112,7 @@ func Example() {
|
||||
|
||||
// We first match the connection against HTTP2 fields. If matched, the
|
||||
// connection will be sent through the "grpcl" listener.
|
||||
grpcl := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc"))
|
||||
grpcl := m.Match(cmux.HTTP2HeaderFieldPrefix("content-type", "application/grpc"))
|
||||
//Otherwise, we match it againts a websocket upgrade request.
|
||||
wsl := m.Match(cmux.HTTP1HeaderField("Upgrade", "websocket"))
|
||||
|
||||
|
123
matchers.go
123
matchers.go
@@ -16,6 +16,7 @@ package cmux
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@@ -37,6 +38,11 @@ func PrefixMatcher(strs ...string) Matcher {
|
||||
return pt.matchPrefix
|
||||
}
|
||||
|
||||
func prefixByteMatcher(list ...[]byte) Matcher {
|
||||
pt := newPatriciaTree(list...)
|
||||
return pt.matchPrefix
|
||||
}
|
||||
|
||||
var defaultHTTPMethods = []string{
|
||||
"OPTIONS",
|
||||
"GET",
|
||||
@@ -57,6 +63,27 @@ func HTTP1Fast(extMethods ...string) Matcher {
|
||||
return PrefixMatcher(append(defaultHTTPMethods, extMethods...)...)
|
||||
}
|
||||
|
||||
// TLS matches HTTPS requests.
|
||||
//
|
||||
// By default, any TLS handshake packet is matched. An optional whitelist
|
||||
// of versions can be passed in to restrict the matcher, for example:
|
||||
// TLS(tls.VersionTLS11, tls.VersionTLS12)
|
||||
func TLS(versions ...int) Matcher {
|
||||
if len(versions) == 0 {
|
||||
versions = []int{
|
||||
tls.VersionSSL30,
|
||||
tls.VersionTLS10,
|
||||
tls.VersionTLS11,
|
||||
tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
prefixes := [][]byte{}
|
||||
for _, v := range versions {
|
||||
prefixes = append(prefixes, []byte{22, byte(v >> 8 & 0xff), byte(v & 0xff)})
|
||||
}
|
||||
return prefixByteMatcher(prefixes...)
|
||||
}
|
||||
|
||||
const maxHTTPRead = 4096
|
||||
|
||||
// HTTP1 parses the first line or upto 4096 bytes of the request to see if
|
||||
@@ -100,15 +127,41 @@ func HTTP2() Matcher {
|
||||
// request of an HTTP 1 connection.
|
||||
func HTTP1HeaderField(name, value string) Matcher {
|
||||
return func(r io.Reader) bool {
|
||||
return matchHTTP1Field(r, name, value)
|
||||
return matchHTTP1Field(r, name, func(gotValue string) bool {
|
||||
return gotValue == value
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP2HeaderField resturns a matcher matching the header fields of the first
|
||||
// HTTP1HeaderFieldPrefix returns a matcher matching the header fields of the
|
||||
// first request of an HTTP 1 connection. If the header with key name has a
|
||||
// value prefixed with valuePrefix, this will match.
|
||||
func HTTP1HeaderFieldPrefix(name, valuePrefix string) Matcher {
|
||||
return func(r io.Reader) bool {
|
||||
return matchHTTP1Field(r, name, func(gotValue string) bool {
|
||||
return strings.HasPrefix(gotValue, valuePrefix)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP2HeaderField returns a matcher matching the header fields of the first
|
||||
// headers frame.
|
||||
func HTTP2HeaderField(name, value string) Matcher {
|
||||
return func(r io.Reader) bool {
|
||||
return matchHTTP2Field(ioutil.Discard, r, name, value)
|
||||
return matchHTTP2Field(ioutil.Discard, r, name, func(gotValue string) bool {
|
||||
return gotValue == value
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP2HeaderFieldPrefix returns a matcher matching the header fields of the
|
||||
// first headers frame. If the header with key name has a value prefixed with
|
||||
// valuePrefix, this will match.
|
||||
func HTTP2HeaderFieldPrefix(name, valuePrefix string) Matcher {
|
||||
return func(r io.Reader) bool {
|
||||
return matchHTTP2Field(ioutil.Discard, r, name, func(gotValue string) bool {
|
||||
return strings.HasPrefix(gotValue, valuePrefix)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,37 +170,66 @@ func HTTP2HeaderField(name, value string) Matcher {
|
||||
// does not block on receiving a SETTING frame.
|
||||
func HTTP2MatchHeaderFieldSendSettings(name, value string) MatchWriter {
|
||||
return func(w io.Writer, r io.Reader) bool {
|
||||
return matchHTTP2Field(w, r, name, value)
|
||||
return matchHTTP2Field(w, r, name, func(gotValue string) bool {
|
||||
return gotValue == value
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP2MatchHeaderFieldPrefixSendSettings matches the header field prefix
|
||||
// and writes the settings to the server. Prefer HTTP2HeaderFieldPrefix over
|
||||
// this one, if the client does not block on receiving a SETTING frame.
|
||||
func HTTP2MatchHeaderFieldPrefixSendSettings(name, valuePrefix string) MatchWriter {
|
||||
return func(w io.Writer, r io.Reader) bool {
|
||||
return matchHTTP2Field(w, r, name, func(gotValue string) bool {
|
||||
return strings.HasPrefix(gotValue, valuePrefix)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func hasHTTP2Preface(r io.Reader) bool {
|
||||
var b [len(http2.ClientPreface)]byte
|
||||
if _, err := io.ReadFull(r, b[:]); err != nil {
|
||||
return false
|
||||
}
|
||||
last := 0
|
||||
|
||||
return string(b[:]) == http2.ClientPreface
|
||||
for {
|
||||
n, err := r.Read(b[last:])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
last += n
|
||||
eq := string(b[:last]) == http2.ClientPreface[:last]
|
||||
if last == len(http2.ClientPreface) {
|
||||
return eq
|
||||
}
|
||||
if !eq {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func matchHTTP1Field(r io.Reader, name, value string) (matched bool) {
|
||||
func matchHTTP1Field(r io.Reader, name string, matches func(string) bool) (matched bool) {
|
||||
req, err := http.ReadRequest(bufio.NewReader(r))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return req.Header.Get(name) == value
|
||||
return matches(req.Header.Get(name))
|
||||
}
|
||||
|
||||
func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool) {
|
||||
func matchHTTP2Field(w io.Writer, r io.Reader, name string, matches func(string) bool) (matched bool) {
|
||||
if !hasHTTP2Preface(r) {
|
||||
return false
|
||||
}
|
||||
|
||||
done := false
|
||||
framer := http2.NewFramer(w, r)
|
||||
hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) {
|
||||
if hf.Name == name && hf.Value == value {
|
||||
matched = true
|
||||
if hf.Name == name {
|
||||
done = true
|
||||
if matches(hf.Value) {
|
||||
matched = true
|
||||
}
|
||||
}
|
||||
})
|
||||
for {
|
||||
@@ -161,17 +243,20 @@ func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool
|
||||
if err := framer.WriteSettings(); err != nil {
|
||||
return false
|
||||
}
|
||||
case *http2.ContinuationFrame:
|
||||
if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil {
|
||||
return false
|
||||
}
|
||||
done = done || f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0
|
||||
case *http2.HeadersFrame:
|
||||
if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil {
|
||||
return false
|
||||
}
|
||||
if matched {
|
||||
return true
|
||||
}
|
||||
done = done || f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0
|
||||
}
|
||||
|
||||
if f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 {
|
||||
return false
|
||||
}
|
||||
if done {
|
||||
return matched
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user