diff --git a/processor.go b/processor.go index a50e474..8459f7f 100644 --- a/processor.go +++ b/processor.go @@ -41,6 +41,9 @@ func (p *processor) terminate() { } func (p *processor) start() { + // NOTE: The call to "restore" needs to complete before starting + // the processor goroutine. + p.restore() go func() { for { select { @@ -92,3 +95,12 @@ func (p *processor) exec() { } }(task) } + +// restore moves all tasks from "in-progress" back to queue +// to restore all unfinished tasks. +func (p *processor) restore() { + err := p.rdb.moveAll(inProgress, defaultQueue) + if err != nil { + log.Printf("[SERVER ERROR] could not move tasks from %q to %q\n", inProgress, defaultQueue) + } +} diff --git a/rdb.go b/rdb.go index ee50e43..3be97d5 100644 --- a/rdb.go +++ b/rdb.go @@ -171,3 +171,16 @@ func (r *rdb) listQueues() []string { } return queues } + +// moveAll moves all tasks from src list to dst list. +func (r *rdb) moveAll(src, dst string) error { + // TODO(hibiken): Lua script + txf := func(tx *redis.Tx) error { + length := tx.LLen(src).Val() + for i := 0; i < int(length); i++ { + tx.RPopLPush(src, dst) + } + return nil + } + return r.client.Watch(txf, src) +} diff --git a/rdb_test.go b/rdb_test.go index fd5410e..cec2de9 100644 --- a/rdb_test.go +++ b/rdb_test.go @@ -2,6 +2,7 @@ package asynq import ( "encoding/json" + "math/rand" "testing" "time" @@ -12,6 +13,10 @@ import ( var client *redis.Client +func init() { + rand.Seed(time.Now().UnixNano()) +} + // setup connects to a redis database and flush all keys // before returning an instance of rdb. func setup() *rdb { @@ -26,14 +31,18 @@ func setup() *rdb { return newRDB(client) } +func randomTask(taskType, qname string) *taskMessage { + return &taskMessage{ + ID: uuid.New(), + Type: taskType, + Queue: qname, + Retry: rand.Intn(100), + } +} + func TestPush(t *testing.T) { r := setup() - msg := &taskMessage{ - ID: uuid.New(), - Type: "sendEmail", - Queue: "default", - Retry: 10, - } + msg := randomTask("send_email", "default") err := r.push(msg) if err != nil { @@ -55,12 +64,7 @@ func TestPush(t *testing.T) { func TestDequeueImmediateReturn(t *testing.T) { r := setup() - msg := &taskMessage{ - ID: uuid.New(), - Type: "GenerateCSVExport", - Queue: "csv", - Retry: 10, - } + msg := randomTask("export_csv", "csv") r.push(msg) res, err := r.dequeue("asynq:queues:csv", time.Second) @@ -92,3 +96,33 @@ func TestDequeueTimeout(t *testing.T) { t.Errorf("err = %v, want %v", err, errQueuePopTimeout) } } + +func TestMoveAll(t *testing.T) { + r := setup() + seed := []*taskMessage{ + randomTask("send_email", "default"), + randomTask("export_csv", "csv"), + randomTask("sync_stuff", "sync"), + } + for _, task := range seed { + bytes, err := json.Marshal(task) + if err != nil { + t.Errorf("json.Marhsal() failed: %v", err) + } + if err := client.LPush(inProgress, string(bytes)).Err(); err != nil { + t.Errorf("LPUSH %q %s failed: %v", inProgress, string(bytes), err) + } + } + + err := r.moveAll(inProgress, defaultQueue) + if err != nil { + t.Errorf("moveAll failed: %v", err) + } + + if l := client.LLen(inProgress).Val(); l != 0 { + t.Errorf("LLEN %q = %d, want 0", inProgress, l) + } + if l := client.LLen(defaultQueue).Val(); int(l) != len(seed) { + t.Errorf("LLEN %q = %d, want %d", defaultQueue, l, len(seed)) + } +}