diff --git a/heartbeat.go b/heartbeat.go index 4e13faf..556a219 100644 --- a/heartbeat.go +++ b/heartbeat.go @@ -51,7 +51,7 @@ func (h *heartbeater) start(wg *sync.WaitGroup) { for { select { case <-h.done: - h.rdb.ClearProcessInfo(h.ps.Get()) + h.rdb.ClearProcessState(h.ps) logger.info("Heartbeater done") return case <-time.After(h.interval): @@ -64,7 +64,7 @@ func (h *heartbeater) start(wg *sync.WaitGroup) { func (h *heartbeater) beat() { // Note: Set TTL to be long enough so that it won't expire before we write again // and short enough to expire quickly once the process is shut down or killed. - err := h.rdb.WriteProcessInfo(h.ps.Get(), h.interval*2) + err := h.rdb.WriteProcessState(h.ps, h.interval*2) if err != nil { logger.error("could not write heartbeat data: %v", err) } diff --git a/internal/base/base.go b/internal/base/base.go index 809beca..2f9ab34 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -22,6 +22,7 @@ const DefaultQueueName = "default" const ( AllProcesses = "asynq:ps" // ZSET psPrefix = "asynq:ps:" // STRING - asynq:ps:: + AllWorkers = "asynq:workers" // ZSET workersPrefix = "asynq:workers:" // HASH - asynq:workers: processedPrefix = "asynq:processed:" // STRING - asynq:processed: failurePrefix = "asynq:failure:" // STRING - asynq:failure: diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index 8f5292f..48bcb56 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -2054,32 +2054,51 @@ func TestRemoveQueueError(t *testing.T) { func TestListProcesses(t *testing.T) { r := setup(t) - ps1 := &base.ProcessInfo{ + started1 := time.Now().Add(-time.Hour) + ps1 := base.NewProcessState("do.droplet1", 1234, 10, map[string]int{"default": 1}, false) + ps1.SetStarted(started1) + ps1.SetStatus(base.StatusRunning) + info1 := &base.ProcessInfo{ Concurrency: 10, Queues: map[string]int{"default": 1}, Host: "do.droplet1", PID: 1234, Status: "running", - Started: time.Now().Add(-time.Hour), - ActiveWorkerCount: 5, + Started: started1, + ActiveWorkerCount: 0, } - ps2 := &base.ProcessInfo{ + started2 := time.Now().Add(-2 * time.Hour) + ps2 := base.NewProcessState("do.droplet2", 9876, 20, map[string]int{"email": 1}, false) + ps2.SetStarted(started2) + ps2.SetStatus(base.StatusStopped) + ps2.AddWorkerStats(h.NewTaskMessage("send_email", nil), time.Now()) + info2 := &base.ProcessInfo{ Concurrency: 20, Queues: map[string]int{"email": 1}, Host: "do.droplet2", PID: 9876, Status: "stopped", - Started: time.Now().Add(-2 * time.Hour), - ActiveWorkerCount: 20, + Started: started2, + ActiveWorkerCount: 1, } tests := []struct { - processes []*base.ProcessInfo + processes []*base.ProcessState + want []*base.ProcessInfo }{ - {processes: []*base.ProcessInfo{}}, - {processes: []*base.ProcessInfo{ps1}}, - {processes: []*base.ProcessInfo{ps1, ps2}}, + { + processes: []*base.ProcessState{}, + want: []*base.ProcessInfo{}, + }, + { + processes: []*base.ProcessState{ps1}, + want: []*base.ProcessInfo{info1}, + }, + { + processes: []*base.ProcessState{ps1, ps2}, + want: []*base.ProcessInfo{info1, info2}, + }, } ignoreOpt := cmpopts.IgnoreUnexported(base.ProcessInfo{}) @@ -2088,7 +2107,7 @@ func TestListProcesses(t *testing.T) { h.FlushDB(t, r.client) for _, ps := range tc.processes { - if err := r.WriteProcessInfo(ps, 5*time.Second); err != nil { + if err := r.WriteProcessState(ps, 5*time.Second); err != nil { t.Fatal(err) } } @@ -2097,7 +2116,7 @@ func TestListProcesses(t *testing.T) { if err != nil { t.Errorf("r.ListProcesses returned an error: 
%v", err)
 		}
-		if diff := cmp.Diff(tc.processes, got, h.SortProcessInfoOpt, ignoreOpt); diff != "" {
+		if diff := cmp.Diff(tc.want, got, h.SortProcessInfoOpt, ignoreOpt); diff != "" {
 			t.Errorf("r.ListProcesses returned %v, want %v; (-want,+got)\n%s",
 				got, tc.processes, diff)
 		}
diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go
index 3c166da..faf12ed 100644
--- a/internal/rdb/rdb.go
+++ b/internal/rdb/rdb.go
@@ -359,56 +359,70 @@ func (r *RDB) forwardSingle(src, dst string) error {
 		[]string{src, dst}, now).Err()
 }
 
-// KEYS[1] -> asynq:ps
-// KEYS[2] -> asynq:ps:<host>:<pid>
-// ARGV[1] -> expiration time
-// ARGV[2] -> TTL in seconds
-// ARGV[3] -> process info
+// KEYS[1] -> asynq:ps:<host>:<pid>
+// KEYS[2] -> asynq:ps
+// KEYS[3] -> asynq:workers:<host>:<pid>
+// KEYS[4] -> asynq:workers
+// ARGV[1] -> expiration time
+// ARGV[2] -> TTL in seconds
+// ARGV[3] -> process info
+// ARGV[4:] -> alternating key-value pairs of (worker id, worker data)
+// Note: Add key to ZSET with expiration time as score.
+// ref: https://github.com/antirez/redis/issues/135#issuecomment-2361996
 var writeProcessInfoCmd = redis.NewScript(`
-redis.call("ZADD", KEYS[1], ARGV[1], KEYS[2])
-redis.call("SETEX", KEYS[2], ARGV[2], ARGV[3])
+redis.call("SETEX", KEYS[1], ARGV[2], ARGV[3])
+redis.call("ZADD", KEYS[2], ARGV[1], KEYS[1])
+for i = 4, table.getn(ARGV)-1, 2 do
+	redis.call("HSET", KEYS[3], ARGV[i], ARGV[i+1])
+end
+redis.call("EXPIRE", KEYS[3], ARGV[2])
+redis.call("ZADD", KEYS[4], ARGV[1], KEYS[3])
 return redis.status_reply("OK")`)
 
-// WriteProcessInfo writes process information to redis with expiration
-// set to the value ttl.
-func (r *RDB) WriteProcessInfo(ps *base.ProcessInfo, ttl time.Duration) error {
-	bytes, err := json.Marshal(ps)
+// WriteProcessState writes process state data to redis with expiration set to the value ttl.
+func (r *RDB) WriteProcessState(ps *base.ProcessState, ttl time.Duration) error {
+	info := ps.Get()
+	bytes, err := json.Marshal(info)
 	if err != nil {
 		return err
 	}
-	// Note: Add key to ZSET with expiration time as score.
-	// ref: https://github.com/antirez/redis/issues/135#issuecomment-2361996
+	var args []interface{} // args to the lua script
 	exp := time.Now().Add(ttl).UTC()
-	key := base.ProcessInfoKey(ps.Host, ps.PID)
-	return writeProcessInfoCmd.Run(r.client, []string{base.AllProcesses, key}, float64(exp.Unix()), ttl.Seconds(), string(bytes)).Err()
-}
-
-// ReadProcessInfo reads process information stored in redis.
-func (r *RDB) ReadProcessInfo(host string, pid int) (*base.ProcessInfo, error) {
-	key := base.ProcessInfoKey(host, pid)
-	data, err := r.client.Get(key).Result()
-	if err != nil {
-		return nil, err
+	workers := ps.GetWorkers()
+	args = append(args, float64(exp.Unix()), ttl.Seconds(), bytes)
+	for _, w := range workers {
+		bytes, err := json.Marshal(w)
+		if err != nil {
+			continue // skip bad data
+		}
+		args = append(args, w.ID.String(), bytes)
 	}
-	var pinfo base.ProcessInfo
-	err = json.Unmarshal([]byte(data), &pinfo)
-	if err != nil {
-		return nil, err
-	}
-	return &pinfo, nil
+	pkey := base.ProcessInfoKey(info.Host, info.PID)
+	wkey := base.WorkersKey(info.Host, info.PID)
+	return writeProcessInfoCmd.Run(r.client,
+		[]string{pkey, base.AllProcesses, wkey, base.AllWorkers},
+		args...).Err()
 }
 
 // KEYS[1] -> asynq:ps
 // KEYS[2] -> asynq:ps:<host>:<pid>
+// KEYS[3] -> asynq:workers
+// KEYS[4] -> asynq:workers:<host>:<pid>
 var clearProcessInfoCmd = redis.NewScript(`
 redis.call("ZREM", KEYS[1], KEYS[2])
 redis.call("DEL", KEYS[2])
+redis.call("ZREM", KEYS[3], KEYS[4])
+redis.call("DEL", KEYS[4])
 return redis.status_reply("OK")`)
 
-// ClearProcessInfo deletes process information from redis.
-func (r *RDB) ClearProcessInfo(ps *base.ProcessInfo) error {
-	key := base.ProcessInfoKey(ps.Host, ps.PID)
-	return clearProcessInfoCmd.Run(r.client, []string{base.AllProcesses, key}).Err()
+// ClearProcessState deletes process state data from redis.
+func (r *RDB) ClearProcessState(ps *base.ProcessState) error {
+	info := ps.Get()
+	host, pid := info.Host, info.PID
+	pkey := base.ProcessInfoKey(host, pid)
+	wkey := base.WorkersKey(host, pid)
+	return clearProcessInfoCmd.Run(r.client,
+		[]string{base.AllProcesses, pkey, base.AllWorkers, wkey}).Err()
 }
 
 // CancelationPubSub returns a pubsub for cancelation messages.
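Note (not part of the patch): WriteProcessState and ClearProcessState above lean on a base.ProcessState type and a base.WorkersKey helper that this diff does not show. The sketch below is a hypothetical reconstruction, inferred only from how those names are used in rdb.go and in the tests further down; the mutex-based layout, the PStatus stub, and the key format in WorkersKey are assumptions, while ProcessInfo, TaskMessage, workersPrefix, and the xid-based task ID come from the existing internal/base package.

```go
package base

import (
	"fmt"
	"sync"
	"time"

	"github.com/rs/xid"
)

// WorkersKey returns the Redis key for the given process's workers,
// assumed to mirror ProcessInfoKey: "asynq:workers:<host>:<pid>".
func WorkersKey(hostname string, pid int) string {
	return fmt.Sprintf("%s%s:%d", workersPrefix, hostname, pid)
}

// PStatus is a process status; the tests use StatusRunning and StatusStopped
// and expect "running"/"stopped" in the encoded ProcessInfo.Status field.
type PStatus int

const (
	StatusIdle PStatus = iota
	StatusRunning
	StatusStopped
)

func (s PStatus) String() string { return [...]string{"idle", "running", "stopped"}[s] }

// WorkerInfo holds information about a worker that is processing a task.
type WorkerInfo struct {
	Host    string
	PID     int
	ID      xid.ID // task ID type used by base.TaskMessage
	Type    string
	Queue   string
	Payload map[string]interface{}
	Started time.Time
}

// ProcessState holds process-level state behind a mutex so the heartbeater
// and the worker goroutines can update it concurrently.
type ProcessState struct {
	mu      sync.Mutex
	info    ProcessInfo // existing type: Host, PID, Concurrency, Queues, StrictPriority, Status, Started, ActiveWorkerCount
	workers map[string]*WorkerInfo
}

// NewProcessState returns a ProcessState for the given process.
func NewProcessState(host string, pid, concurrency int, queues map[string]int, strict bool) *ProcessState {
	return &ProcessState{
		info:    ProcessInfo{Host: host, PID: pid, Concurrency: concurrency, Queues: queues, StrictPriority: strict},
		workers: make(map[string]*WorkerInfo),
	}
}

func (ps *ProcessState) SetStarted(t time.Time) {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	ps.info.Started = t
}

func (ps *ProcessState) SetStatus(s PStatus) {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	ps.info.Status = s.String()
}

// AddWorkerStats records that a worker began processing msg at the given time.
func (ps *ProcessState) AddWorkerStats(msg *TaskMessage, started time.Time) {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	ps.workers[msg.ID.String()] = &WorkerInfo{
		Host: ps.info.Host, PID: ps.info.PID, ID: msg.ID,
		Type: msg.Type, Queue: msg.Queue, Payload: msg.Payload, Started: started,
	}
}

// Get returns a snapshot of the process info; ActiveWorkerCount is derived
// from the workers currently tracked, matching what the tests below expect.
func (ps *ProcessState) Get() *ProcessInfo {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	info := ps.info
	info.ActiveWorkerCount = len(ps.workers)
	return &info
}

// GetWorkers returns a snapshot of the currently tracked workers.
func (ps *ProcessState) GetWorkers() []*WorkerInfo {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	var ws []*WorkerInfo
	for _, w := range ps.workers {
		ws = append(ws, w)
	}
	return ws
}
```

With this shape, WriteProcessState only ever sees snapshots (Get/GetWorkers), so the Lua script is fed a consistent ProcessInfo plus worker pairs without holding the lock across the Redis round trip.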
diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index 6dcdc93..e0b4669 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -5,8 +5,8 @@ package rdb import ( + "encoding/json" "fmt" - "strconv" "testing" "time" @@ -741,81 +741,237 @@ func TestCheckAndEnqueue(t *testing.T) { } } -func TestReadWriteClearProcessInfo(t *testing.T) { +func TestWriteProcessState(t *testing.T) { r := setup(t) - pinfo := &base.ProcessInfo{ + host, pid := "localhost", 98765 + queues := map[string]int{"default": 2, "email": 5, "low": 1} + + started := time.Now() + ps := base.NewProcessState(host, pid, 10, queues, false) + ps.SetStarted(started) + ps.SetStatus(base.StatusRunning) + ttl := 5 * time.Second + + h.FlushDB(t, r.client) + + err := r.WriteProcessState(ps, ttl) + if err != nil { + t.Errorf("r.WriteProcessState returned an error: %v", err) + } + + // Check ProcessInfo was written correctly + pkey := base.ProcessInfoKey(host, pid) + data := r.client.Get(pkey).Val() + var got base.ProcessInfo + err = json.Unmarshal([]byte(data), &got) + if err != nil { + t.Fatalf("could not decode json: %v", err) + } + want := base.ProcessInfo{ + Host: "localhost", + PID: 98765, Concurrency: 10, Queues: map[string]int{"default": 2, "email": 5, "low": 1}, - PID: 98765, - Host: "localhost", + StrictPriority: false, Status: "running", - Started: time.Now(), - ActiveWorkerCount: 1, + Started: started, + ActiveWorkerCount: 0, + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("persisted ProcessInfo was %v, want %v; (-want,+got)\n%s", + got, want, diff) + } + // Check ProcessInfo TTL was set correctly + gotTTL := r.client.TTL(pkey).Val() + timeCmpOpt := cmpopts.EquateApproxTime(time.Second) + if !cmp.Equal(ttl, gotTTL, timeCmpOpt) { + t.Errorf("TTL of %q was %v, want %v", pkey, gotTTL, ttl) + } + // Check ProcessInfo key was added to the set correctly + gotProcesses := r.client.ZRange(base.AllProcesses, 0, -1).Val() + wantProcesses := []string{pkey} + if diff := cmp.Diff(wantProcesses, gotProcesses); diff != "" { + t.Errorf("%q contained %v, want %v", base.AllProcesses, gotProcesses, wantProcesses) } - tests := []struct { - pi *base.ProcessInfo - ttl time.Duration - }{ - {pinfo, 5 * time.Second}, + // Check WorkersInfo was written correctly + wkey := base.WorkersKey(host, pid) + workerExist := r.client.Exists(wkey).Val() + if workerExist != 0 { + t.Errorf("%q key exists", wkey) } - - for _, tc := range tests { - h.FlushDB(t, r.client) - - err := r.WriteProcessInfo(tc.pi, tc.ttl) - if err != nil { - t.Errorf("r.WriteProcessInfo returned an error: %v", err) - continue - } - - got, err := r.ReadProcessInfo(tc.pi.Host, tc.pi.PID) - if err != nil { - t.Errorf("r.ReadProcessInfo returned an error: %v", err) - continue - } - - ignoreOpt := cmpopts.IgnoreUnexported(base.ProcessInfo{}) - if diff := cmp.Diff(tc.pi, got, ignoreOpt); diff != "" { - t.Errorf("r.ReadProcessInfo(%q, %d) = %+v, want %+v; (-want,+got)\n%s", - tc.pi.Host, tc.pi.PID, got, tc.pi, diff) - } - - key := base.ProcessInfoKey(tc.pi.Host, tc.pi.PID) - gotTTL := r.client.TTL(key).Val() - if !cmp.Equal(tc.ttl, gotTTL, timeCmpOpt) { - t.Errorf("redis TTL %q returned %v, want %v", key, gotTTL, tc.ttl) - } - - now := time.Now().UTC() - allKeys, err := r.client.ZRangeByScore(base.AllProcesses, &redis.ZRangeBy{ - Min: strconv.Itoa(int(now.Unix())), - Max: "+inf", - }).Result() - if err != nil { - t.Errorf("redis ZRANGEBYSCORE %q %d +inf returned an error: %v", - base.AllProcesses, now.Unix(), err) - continue - } - - wantAllKeys := 
[]string{key} - if diff := cmp.Diff(wantAllKeys, allKeys); diff != "" { - t.Errorf("all keys = %v, want %v; (-want,+got)\n%s", allKeys, wantAllKeys, diff) - } - - if err := r.ClearProcessInfo(tc.pi); err != nil { - t.Errorf("r.ClearProcessInfo returned an error: %v", err) - continue - } - - // 1 means key exists - if r.client.Exists(key).Val() == 1 { - t.Errorf("expected %q to be deleted", key) - } - - if r.client.ZCard(base.AllProcesses).Val() != 0 { - t.Errorf("expected %q to be empty", base.AllProcesses) - } - + // Check WorkersInfo key was added to the set correctly + gotWorkerKeys := r.client.ZRange(base.AllWorkers, 0, -1).Val() + wantWorkerKeys := []string{wkey} + if diff := cmp.Diff(wantWorkerKeys, gotWorkerKeys); diff != "" { + t.Errorf("%q contained %v, want %v", base.AllWorkers, gotWorkerKeys, wantWorkerKeys) + } +} + +func TestWriteProcessStateWithWorkers(t *testing.T) { + r := setup(t) + host, pid := "localhost", 98765 + queues := map[string]int{"default": 2, "email": 5, "low": 1} + concurrency := 10 + + started := time.Now().Add(-10 * time.Minute) + w1Started := time.Now().Add(-time.Minute) + w2Started := time.Now().Add(-time.Second) + msg1 := h.NewTaskMessage("send_email", map[string]interface{}{"user_id": "123"}) + msg2 := h.NewTaskMessage("gen_thumbnail", map[string]interface{}{"path": "some/path/to/imgfile"}) + ps := base.NewProcessState(host, pid, concurrency, queues, false) + ps.SetStarted(started) + ps.SetStatus(base.StatusRunning) + ps.AddWorkerStats(msg1, w1Started) + ps.AddWorkerStats(msg2, w2Started) + ttl := 5 * time.Second + + h.FlushDB(t, r.client) + + err := r.WriteProcessState(ps, ttl) + if err != nil { + t.Errorf("r.WriteProcessState returned an error: %v", err) + } + + // Check ProcessInfo was written correctly + pkey := base.ProcessInfoKey(host, pid) + data := r.client.Get(pkey).Val() + var got base.ProcessInfo + err = json.Unmarshal([]byte(data), &got) + if err != nil { + t.Fatalf("could not decode json: %v", err) + } + want := base.ProcessInfo{ + Host: host, + PID: pid, + Concurrency: concurrency, + Queues: queues, + StrictPriority: false, + Status: "running", + Started: started, + ActiveWorkerCount: 2, + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("persisted ProcessInfo was %v, want %v; (-want,+got)\n%s", + got, want, diff) + } + // Check ProcessInfo TTL was set correctly + gotTTL := r.client.TTL(pkey).Val() + timeCmpOpt := cmpopts.EquateApproxTime(time.Second) + if !cmp.Equal(ttl, gotTTL, timeCmpOpt) { + t.Errorf("TTL of %q was %v, want %v", pkey, gotTTL, ttl) + } + // Check ProcessInfo key was added to the set correctly + gotProcesses := r.client.ZRange(base.AllProcesses, 0, -1).Val() + wantProcesses := []string{pkey} + if diff := cmp.Diff(wantProcesses, gotProcesses); diff != "" { + t.Errorf("%q contained %v, want %v", base.AllProcesses, gotProcesses, wantProcesses) + } + + // Check WorkersInfo was written correctly + wkey := base.WorkersKey(host, pid) + wdata := r.client.HGetAll(wkey).Val() + if len(wdata) != 2 { + t.Fatalf("HGETALL %q returned a hash of size %d, want 2", wkey, len(wdata)) + } + gotWorkers := make(map[string]*base.WorkerInfo) + for key, val := range wdata { + var w base.WorkerInfo + if err := json.Unmarshal([]byte(val), &w); err != nil { + t.Fatalf("could not unmarshal worker's data: %v", err) + } + gotWorkers[key] = &w + } + wantWorkers := map[string]*base.WorkerInfo{ + msg1.ID.String(): &base.WorkerInfo{ + Host: host, + PID: pid, + ID: msg1.ID, + Type: msg1.Type, + Queue: msg1.Queue, + Payload: msg1.Payload, + 
Started: w1Started, + }, + msg2.ID.String(): &base.WorkerInfo{ + Host: host, + PID: pid, + ID: msg2.ID, + Type: msg2.Type, + Queue: msg2.Queue, + Payload: msg2.Payload, + Started: w2Started, + }, + } + if diff := cmp.Diff(wantWorkers, gotWorkers); diff != "" { + t.Errorf("persisted workers info was %v, want %v; (-want,+got)\n%s", + gotWorkers, wantWorkers, diff) + } + + // Check WorkersInfo TTL was set correctly + gotTTL = r.client.TTL(wkey).Val() + if !cmp.Equal(ttl, gotTTL, timeCmpOpt) { + t.Errorf("TTL of %q was %v, want %v", wkey, gotTTL, ttl) + } + // Check WorkersInfo key was added to the set correctly + gotWorkerKeys := r.client.ZRange(base.AllWorkers, 0, -1).Val() + wantWorkerKeys := []string{wkey} + if diff := cmp.Diff(wantWorkerKeys, gotWorkerKeys); diff != "" { + t.Errorf("%q contained %v, want %v", base.AllWorkers, gotWorkerKeys, wantWorkerKeys) + } +} + +func TestClearProcessState(t *testing.T) { + r := setup(t) + host, pid := "127.0.0.1", 1234 + + h.FlushDB(t, r.client) + + pkey := base.ProcessInfoKey(host, pid) + wkey := base.WorkersKey(host, pid) + otherPKey := base.ProcessInfoKey("otherhost", 12345) + otherWKey := base.WorkersKey("otherhost", 12345) + // Populate the keys. + if err := r.client.Set(pkey, "process-info", 0).Err(); err != nil { + t.Fatal(err) + } + if err := r.client.HSet(wkey, "worker-key", "worker-info").Err(); err != nil { + t.Fatal(err) + } + if err := r.client.ZAdd(base.AllProcesses, &redis.Z{Member: pkey}).Err(); err != nil { + t.Fatal(err) + } + if err := r.client.ZAdd(base.AllProcesses, &redis.Z{Member: otherPKey}).Err(); err != nil { + t.Fatal(err) + } + if err := r.client.ZAdd(base.AllWorkers, &redis.Z{Member: wkey}).Err(); err != nil { + t.Fatal(err) + } + if err := r.client.ZAdd(base.AllWorkers, &redis.Z{Member: otherWKey}).Err(); err != nil { + t.Fatal(err) + } + + ps := base.NewProcessState(host, pid, 10, map[string]int{"default": 1}, false) + + err := r.ClearProcessState(ps) + if err != nil { + t.Fatalf("(*RDB).ClearProcessState failed: %v", err) + } + + // Check all keys are cleared + if r.client.Exists(pkey).Val() != 0 { + t.Errorf("Redis key %q exists", pkey) + } + if r.client.Exists(wkey).Val() != 0 { + t.Errorf("Redis key %q exists", wkey) + } + gotProcessKeys := r.client.ZRange(base.AllProcesses, 0, -1).Val() + wantProcessKeys := []string{otherPKey} + if diff := cmp.Diff(wantProcessKeys, gotProcessKeys); diff != "" { + t.Errorf("%q contained %v, want %v", base.AllProcesses, gotProcessKeys, wantProcessKeys) + } + gotWorkerKeys := r.client.ZRange(base.AllWorkers, 0, -1).Val() + wantWorkerKeys := []string{otherWKey} + if diff := cmp.Diff(wantWorkerKeys, gotWorkerKeys); diff != "" { + t.Errorf("%q contained %v, want %v", base.AllWorkers, gotWorkerKeys, wantWorkerKeys) } }
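Note (not part of the patch): putting the pieces together, the heartbeat loop shown at the top of this diff drives the new API roughly as sketched below. This is illustrative only; rdb.NewRDB, (*RDB).Close, and the go-redis v7 client come from the existing codebase rather than this change, and the host, PID, queues, and Redis address are placeholders (internal packages, so this only compiles from within the asynq module).

```go
package main

import (
	"log"
	"time"

	"github.com/go-redis/redis/v7"

	"github.com/hibiken/asynq/internal/base"
	"github.com/hibiken/asynq/internal/rdb"
)

func main() {
	r := rdb.NewRDB(redis.NewClient(&redis.Options{Addr: "localhost:6379"}))
	defer r.Close()

	// One ProcessState per server process; workers add their stats as they
	// pick up tasks (see AddWorkerStats in the tests above).
	ps := base.NewProcessState("example.host", 1234, 10, map[string]int{"default": 1}, false)
	ps.SetStarted(time.Now())
	ps.SetStatus(base.StatusRunning)

	interval := 5 * time.Second
	// Each beat refreshes:
	//   asynq:ps:example.host:1234      STRING, JSON-encoded ProcessInfo, TTL = 2*interval
	//   asynq:workers:example.host:1234 HASH, worker id -> JSON-encoded WorkerInfo, same TTL
	//   asynq:ps and asynq:workers      ZSETs of the keys above, scored by expiration time
	if err := r.WriteProcessState(ps, interval*2); err != nil {
		log.Printf("could not write process state: %v", err)
	}

	// On shutdown, drop this process's entries immediately instead of
	// waiting for the TTLs to expire.
	if err := r.ClearProcessState(ps); err != nil {
		log.Printf("could not clear process state: %v", err)
	}
}
```

If a process dies without calling ClearProcessState, the STRING and HASH keys simply expire via their TTLs, and readers can skip ZSET members whose score (the recorded expiration time) is already in the past.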