From 4595bd41c3bb7deab60d425f2d1bde6caaa621c9 Mon Sep 17 00:00:00 2001 From: Ken Hibino Date: Wed, 3 Jun 2020 06:44:12 -0700 Subject: [PATCH] Add Pause and Unpause methods to rdb --- internal/base/base.go | 1 + internal/rdb/inspect.go | 30 +++++++ internal/rdb/inspect_test.go | 161 +++++++++++++++++++++++++++++++++++ 3 files changed, 192 insertions(+) diff --git a/internal/base/base.go b/internal/base/base.go index e6bd364..0bb8465 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -34,6 +34,7 @@ const ( RetryQueue = "asynq:retry" // ZSET DeadQueue = "asynq:dead" // ZSET InProgressQueue = "asynq:in_progress" // LIST + PausedQueues = "asynq:paused" // SET CancelChannel = "asynq:cancel" // PubSub channel ) diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index 79c1c3e..99f30a4 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -830,3 +830,33 @@ func (r *RDB) ListWorkers() ([]*base.WorkerInfo, error) { } return workers, nil } + +// KEYS[1] -> asynq:paused +// ARGV[1] -> asynq:queues: - queue to pause +var pauseCmd = redis.NewScript(` +local ismem = redis.call("SISMEMBER", KEYS[1], ARGV[1]) +if ismem == 1 then + return redis.error_reply("queue is already paused") +end +return redis.call("SADD", KEYS[1], ARGV[1])`) + +// Pause pauses processing of tasks from the given queue. +func (r *RDB) Pause(qname string) error { + qkey := base.QueueKey(qname) + return pauseCmd.Run(r.client, []string{base.PausedQueues}, qkey).Err() +} + +// KEYS[1] -> asynq:paused +// ARGV[1] -> asynq:queues: - queue to unpause +var unpauseCmd = redis.NewScript(` +local ismem = redis.call("SISMEMBER", KEYS[1], ARGV[1]) +if ismem == 0 then + return redis.error_reply("queue is not paused") +end +return redis.call("SREM", KEYS[1], ARGV[1])`) + +// Unpause resumes processing of tasks from the given queue. +func (r *RDB) Unpause(qname string) error { + qkey := base.QueueKey(qname) + return unpauseCmd.Run(r.client, []string{base.PausedQueues}, qkey).Err() +} diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index c0b2c09..986e254 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -2156,3 +2156,164 @@ func TestListWorkers(t *testing.T) { } } } + +func TestPause(t *testing.T) { + r := setup(t) + + tests := []struct { + initial []string // initial queue keys in the set + qname string // queue name to pause + want []string // expected queue keys in the set + }{ + {[]string{}, "default", []string{"asynq:queues:default"}}, + {[]string{"asynq:queues:default"}, "critical", []string{"asynq:queues:default", "asynq:queues:critical"}}, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + + // Set up initial state. + for _, qkey := range tc.initial { + if err := r.client.SAdd(base.PausedQueues, qkey).Err(); err != nil { + t.Fatal(err) + } + } + + err := r.Pause(tc.qname) + if err != nil { + t.Errorf("Pause(%q) returned error: %v", tc.qname, err) + continue + } + + got, err := r.client.SMembers(base.PausedQueues).Result() + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(tc.want, got, h.SortStringSliceOpt); diff != "" { + t.Errorf("%q has members %v, want %v; (-want,+got)\n%s", + base.PausedQueues, got, tc.want, diff) + } + } +} + +func TestPauseError(t *testing.T) { + r := setup(t) + + tests := []struct { + desc string // test case description + initial []string // initial queue keys in the set + qname string // queue name to pause + want []string // expected queue keys in the set + }{ + {"queue already paused", []string{"asynq:queues:default"}, "default", []string{"asynq:queues:default"}}, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + + // Set up initial state. + for _, qkey := range tc.initial { + if err := r.client.SAdd(base.PausedQueues, qkey).Err(); err != nil { + t.Fatal(err) + } + } + + err := r.Pause(tc.qname) + if err == nil { + t.Errorf("%s; Pause(%q) returned nil: want error", tc.desc, tc.qname) + continue + } + + got, err := r.client.SMembers(base.PausedQueues).Result() + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(tc.want, got, h.SortStringSliceOpt); diff != "" { + t.Errorf("%s; %q has members %v, want %v; (-want,+got)\n%s", + tc.desc, base.PausedQueues, got, tc.want, diff) + } + } +} + +func TestUnpause(t *testing.T) { + r := setup(t) + + tests := []struct { + initial []string // initial queue keys in the set + qname string // queue name to unpause + want []string // expected queue keys in the set + }{ + {[]string{"asynq:queues:default"}, "default", []string{}}, + {[]string{"asynq:queues:default", "asynq:queues:low"}, "low", []string{"asynq:queues:default"}}, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + + // Set up initial state. + for _, qkey := range tc.initial { + if err := r.client.SAdd(base.PausedQueues, qkey).Err(); err != nil { + t.Fatal(err) + } + } + + err := r.Unpause(tc.qname) + if err != nil { + t.Errorf("Unpause(%q) returned error: %v", tc.qname, err) + continue + } + + got, err := r.client.SMembers(base.PausedQueues).Result() + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(tc.want, got, h.SortStringSliceOpt); diff != "" { + t.Errorf("%q has members %v, want %v; (-want,+got)\n%s", + base.PausedQueues, got, tc.want, diff) + } + } +} + +func TestUnpauseError(t *testing.T) { + r := setup(t) + + tests := []struct { + desc string // test case description + initial []string // initial queue keys in the set + qname string // queue name to unpause + want []string // expected queue keys in the set + }{ + {"set is empty", []string{}, "default", []string{}}, + {"queue is not in the set", []string{"asynq:queues:default"}, "low", []string{"asynq:queues:default"}}, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + + // Set up initial state. + for _, qkey := range tc.initial { + if err := r.client.SAdd(base.PausedQueues, qkey).Err(); err != nil { + t.Fatal(err) + } + } + + err := r.Unpause(tc.qname) + if err == nil { + t.Errorf("%s; Unpause(%q) returned nil: want error", tc.desc, tc.qname) + continue + } + + got, err := r.client.SMembers(base.PausedQueues).Result() + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(tc.want, got, h.SortStringSliceOpt); diff != "" { + t.Errorf("%s; %q has members %v, want %v; (-want,+got)\n%s", + tc.desc, base.PausedQueues, got, tc.want, diff) + } + } +}