diff --git a/rdb_test.go b/rdb_test.go index ba0416b..4e27f5e 100644 --- a/rdb_test.go +++ b/rdb_test.go @@ -18,6 +18,12 @@ func init() { rand.Seed(time.Now().UnixNano()) } +var sortStrOpt = cmp.Transformer("SortStr", func(in []string) []string { + out := append([]string(nil), in...) // Copy input to avoid mutating it + sort.Strings(out) + return out +}) + // setup connects to a redis database and flush all keys // before returning an instance of rdb. func setup(t *testing.T) *rdb { @@ -118,31 +124,76 @@ func TestDequeue(t *testing.T) { func TestMoveAll(t *testing.T) { r := setup(t) - seed := []*taskMessage{ - randomTask("send_email", "default", nil), - randomTask("export_csv", "csv", nil), - randomTask("sync_stuff", "sync", nil), - } - for _, task := range seed { - bytes, err := json.Marshal(task) - if err != nil { - t.Fatal(err) - } - if err := client.LPush(inProgress, string(bytes)).Err(); err != nil { - t.Fatal(err) - } - } - - err := r.moveAll(inProgress, defaultQueue) + t1 := randomTask("send_email", "default", nil) + t2 := randomTask("export_csv", "csv", nil) + t3 := randomTask("sync_stuff", "sync", nil) + json1, err := json.Marshal(t1) if err != nil { - t.Errorf("moveAll(%q, %q) = %v, want nil", inProgress, defaultQueue, err) + t.Fatal(err) + } + json2, err := json.Marshal(t2) + if err != nil { + t.Fatal(err) + } + json3, err := json.Marshal(t3) + if err != nil { + t.Fatal(err) } - if l := client.LLen(inProgress).Val(); l != 0 { - t.Errorf("LLEN %q = %d, want 0", inProgress, l) + tests := []struct { + beforeSrc []string + beforeDst []string + afterSrc []string + afterDst []string + }{ + { + beforeSrc: []string{string(json1), string(json2), string(json3)}, + beforeDst: []string{}, + afterSrc: []string{}, + afterDst: []string{string(json1), string(json2), string(json3)}, + }, + { + beforeSrc: []string{}, + beforeDst: []string{string(json1), string(json2), string(json3)}, + afterSrc: []string{}, + afterDst: []string{string(json1), string(json2), string(json3)}, + }, + { + beforeSrc: []string{string(json2), string(json3)}, + beforeDst: []string{string(json1)}, + afterSrc: []string{}, + afterDst: []string{string(json1), string(json2), string(json3)}, + }, } - if l := client.LLen(defaultQueue).Val(); int(l) != len(seed) { - t.Errorf("LLEN %q = %d, want %d", defaultQueue, l, len(seed)) + + for _, tc := range tests { + // clean up db before each test case. + if err := client.FlushDB().Err(); err != nil { + t.Error(err) + continue + } + // seed src list. + for _, msg := range tc.beforeSrc { + client.LPush(inProgress, msg) + } + // seed dst list. + for _, msg := range tc.beforeDst { + client.LPush(defaultQueue, msg) + } + + if err := r.moveAll(inProgress, defaultQueue); err != nil { + t.Errorf("(*rdb).moveAll(%q, %q) = %v, want nil", inProgress, defaultQueue, err) + continue + } + + gotSrc := client.LRange(inProgress, 0, -1).Val() + if diff := cmp.Diff(tc.afterSrc, gotSrc, sortStrOpt); diff != "" { + t.Errorf("mismatch found in %q (-want, +got)\n%s", inProgress, diff) + } + gotDst := client.LRange(defaultQueue, 0, -1).Val() + if diff := cmp.Diff(tc.afterDst, gotDst, sortStrOpt); diff != "" { + t.Errorf("mismatch found in %q (-want, +got)\n%s", defaultQueue, diff) + } } } @@ -189,12 +240,6 @@ func TestForward(t *testing.T) { }, } - sortStrOpt := cmp.Transformer("SortStr", func(in []string) []string { - out := append([]string(nil), in...) // Copy input to avoid mutating it - sort.Strings(out) - return out - }) - for _, tc := range tests { // clean up db before each test case. if err := client.FlushDB().Err(); err != nil {