diff --git a/asynq.go b/asynq.go index 234d4b9..7f5a630 100644 --- a/asynq.go +++ b/asynq.go @@ -8,6 +8,7 @@ import ( "context" "crypto/tls" "fmt" + "net" "net/url" "strconv" "strings" @@ -430,7 +431,7 @@ func ParseRedisURI(uri string) (RedisConnOpt, error) { return nil, fmt.Errorf("asynq: could not parse redis uri: %v", err) } switch u.Scheme { - case "redis": + case "redis", "rediss": return parseRedisURI(u) case "redis-socket": return parseRedisSocketURI(u) @@ -444,6 +445,8 @@ func ParseRedisURI(uri string) (RedisConnOpt, error) { func parseRedisURI(u *url.URL) (RedisConnOpt, error) { var db int var err error + var redisConnOpt RedisClientOpt + if len(u.Path) > 0 { xs := strings.Split(strings.Trim(u.Path, "/"), "/") db, err = strconv.Atoi(xs[0]) @@ -455,7 +458,20 @@ func parseRedisURI(u *url.URL) (RedisConnOpt, error) { if v, ok := u.User.Password(); ok { password = v } - return RedisClientOpt{Addr: u.Host, DB: db, Password: password}, nil + + if u.Scheme == "rediss" { + h, _, err := net.SplitHostPort(u.Host) + if err != nil { + h = u.Host + } + redisConnOpt.TLSConfig = &tls.Config{ServerName: h} + } + + redisConnOpt.Addr = u.Host + redisConnOpt.Password = password + redisConnOpt.DB = db + + return redisConnOpt, nil } func parseRedisSocketURI(u *url.URL) (RedisConnOpt, error) { diff --git a/asynq_test.go b/asynq_test.go index e5fdd62..d1081a8 100644 --- a/asynq_test.go +++ b/asynq_test.go @@ -5,6 +5,7 @@ package asynq import ( + "crypto/tls" "flag" "sort" "strings" @@ -12,6 +13,7 @@ import ( "github.com/go-redis/redis/v8" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" h "github.com/hibiken/asynq/internal/asynqtest" "github.com/hibiken/asynq/internal/log" ) @@ -99,6 +101,10 @@ func TestParseRedisURI(t *testing.T) { "redis://localhost:6379", RedisClientOpt{Addr: "localhost:6379"}, }, + { + "rediss://localhost:6379", + RedisClientOpt{Addr: "localhost:6379", TLSConfig: &tls.Config{ServerName: "localhost"}}, + }, { "redis://localhost:6379/3", RedisClientOpt{Addr: "localhost:6379", DB: 3}, @@ -151,7 +157,7 @@ func TestParseRedisURI(t *testing.T) { continue } - if diff := cmp.Diff(tc.want, got); diff != "" { + if diff := cmp.Diff(tc.want, got, cmpopts.IgnoreUnexported(tls.Config{})); diff != "" { t.Errorf("ParseRedisURI(%q) = %+v, want %+v\n(-want,+got)\n%s", tc.uri, got, tc.want, diff) } }