diff --git a/background.go b/background.go index 2a01f96..550a887 100644 --- a/background.go +++ b/background.go @@ -33,10 +33,11 @@ type Background struct { mu sync.Mutex running bool - rdb *rdb.RDB - scheduler *scheduler - processor *processor - syncer *syncer + rdb *rdb.RDB + scheduler *scheduler + processor *processor + syncer *syncer + heartbeater *heartbeater } // Config specifies the background-task processing behavior. @@ -109,16 +110,24 @@ func NewBackground(r RedisConnOpt, cfg *Config) *Background { } qcfg := normalizeQueueCfg(queues) + host, err := os.Hostname() + if err != nil { + host = "unknown-host" + } + pid := os.Getpid() + + rdb := rdb.NewRDB(createRedisClient(r)) syncRequestCh := make(chan *syncRequest) syncer := newSyncer(syncRequestCh, 5*time.Second) - rdb := rdb.NewRDB(createRedisClient(r)) + heartbeater := newHeartbeater(rdb, 5*time.Second, host, pid, queues, n) scheduler := newScheduler(rdb, 5*time.Second, qcfg) processor := newProcessor(rdb, n, qcfg, cfg.StrictPriority, delayFunc, syncRequestCh) return &Background{ - rdb: rdb, - scheduler: scheduler, - processor: processor, - syncer: syncer, + rdb: rdb, + scheduler: scheduler, + processor: processor, + syncer: syncer, + heartbeater: heartbeater, } } @@ -165,6 +174,7 @@ func (bg *Background) Run(handler Handler) { sig := <-sigs if sig == syscall.SIGTSTP { bg.processor.stop() + bg.heartbeater.setState("stopped") continue } break @@ -184,6 +194,7 @@ func (bg *Background) start(handler Handler) { bg.running = true bg.processor.handler = handler + bg.heartbeater.start() bg.syncer.start() bg.scheduler.start() bg.processor.start() @@ -202,6 +213,7 @@ func (bg *Background) stop() { // Note: processor and all worker goroutines need to be exited // before shutting down syncer to avoid goroutine leak. bg.syncer.terminate() + bg.heartbeater.terminate() bg.rdb.Close() bg.processor.handler = nil diff --git a/heartbeat.go b/heartbeat.go new file mode 100644 index 0000000..d9ead5d --- /dev/null +++ b/heartbeat.go @@ -0,0 +1,78 @@ +// Copyright 2020 Kentaro Hibino. All rights reserved. +// Use of this source code is governed by a MIT license +// that can be found in the LICENSE file. + +package asynq + +import ( + "sync" + "time" + + "github.com/hibiken/asynq/internal/base" + "github.com/hibiken/asynq/internal/rdb" +) + +// heartbeater is responsible for writing process status to redis periodically to +// indicate that the background worker process is up. +type heartbeater struct { + rdb *rdb.RDB + + mu sync.Mutex + ps *base.ProcessStatus + + // channel to communicate back to the long running "heartbeater" goroutine. + done chan struct{} + + // interval between heartbeats. + interval time.Duration +} + +func newHeartbeater(rdb *rdb.RDB, interval time.Duration, host string, pid int, queues map[string]uint, n int) *heartbeater { + ps := &base.ProcessStatus{ + Concurrency: n, + Queues: queues, + Host: host, + PID: pid, + } + return &heartbeater{ + rdb: rdb, + ps: ps, + done: make(chan struct{}), + interval: interval, + } +} + +func (h *heartbeater) terminate() { + logger.info("Heartbeater shutting down...") + // Signal the heartbeater goroutine to stop. + h.done <- struct{}{} +} + +func (h *heartbeater) setState(state string) { + h.mu.Lock() + defer h.mu.Unlock() + h.ps.State = state +} + +func (h *heartbeater) start() { + h.ps.Started = time.Now() + h.ps.State = "running" + go func() { + for { + select { + case <-h.done: + logger.info("Heartbeater done") + return + case <-time.After(h.interval): + // Note: Set TTL to be long enough value so that it won't expire before we write again + // and short enough to expire quickly once process is shut down. + h.mu.Lock() + err := h.rdb.WriteProcessStatus(h.ps, h.interval*2) + h.mu.Unlock() + if err != nil { + logger.error("could not write heartbeat data: %v", err) + } + } + } + }() +} diff --git a/heartbeat_test.go b/heartbeat_test.go new file mode 100644 index 0000000..9aee07f --- /dev/null +++ b/heartbeat_test.go @@ -0,0 +1,86 @@ +// Copyright 2020 Kentaro Hibino. All rights reserved. +// Use of this source code is governed by a MIT license +// that can be found in the LICENSE file. + +package asynq + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + h "github.com/hibiken/asynq/internal/asynqtest" + "github.com/hibiken/asynq/internal/base" + "github.com/hibiken/asynq/internal/rdb" +) + +func TestHeartbeater(t *testing.T) { + r := setup(t) + rdbClient := rdb.NewRDB(r) + + tests := []struct { + interval time.Duration + host string + pid int + queues map[string]uint + concurrency int + }{ + {time.Second, "some.address.ec2.aws.com", 45678, map[string]uint{"default": 1}, 10}, + } + + for _, tc := range tests { + h.FlushDB(t, r) + + hb := newHeartbeater(rdbClient, tc.interval, tc.host, tc.pid, tc.queues, tc.concurrency) + + want := &base.ProcessStatus{ + Host: tc.host, + PID: tc.pid, + Queues: tc.queues, + Concurrency: tc.concurrency, + Started: time.Now(), + State: "running", + } + hb.start() + + // allow for heartbeater to write to redis + time.Sleep(tc.interval * 2) + + got, err := rdbClient.ReadProcessStatus(tc.host, tc.pid) + if err != nil { + t.Errorf("could not read process status from redis: %v", err) + hb.terminate() + continue + } + + var timeCmpOpt = cmpopts.EquateApproxTime(10 * time.Millisecond) + if diff := cmp.Diff(want, got, timeCmpOpt); diff != "" { + t.Errorf("redis stored process status %+v, want %+v; (-want, +got)\n%s", got, want, diff) + hb.terminate() + continue + } + + // state change + hb.setState("stopped") + + // allow for heartbeater to write to redis + time.Sleep(tc.interval * 2) + + want.State = "stopped" + got, err = rdbClient.ReadProcessStatus(tc.host, tc.pid) + if err != nil { + t.Errorf("could not read process status from redis: %v", err) + hb.terminate() + continue + } + + if diff := cmp.Diff(want, got, timeCmpOpt); diff != "" { + t.Errorf("redis stored process status %+v, want %+v; (-want, +got)\n%s", got, want, diff) + hb.terminate() + continue + } + + hb.terminate() + } +} diff --git a/internal/base/base.go b/internal/base/base.go index 1e33673..87c9dab 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -6,6 +6,7 @@ package base import ( + "fmt" "strings" "time" @@ -17,6 +18,7 @@ const DefaultQueueName = "default" // Redis keys const ( + psPrefix = "asynq:ps:" // HASH processedPrefix = "asynq:processed:" // STRING - asynq:processed: failurePrefix = "asynq:failure:" // STRING - asynq:failure: QueuePrefix = "asynq:queues:" // LIST - asynq:queues: @@ -45,6 +47,11 @@ func FailureKey(t time.Time) string { return failurePrefix + t.UTC().Format("2006-01-02") } +// ProcessStatusKey returns a redis key string for process status. +func ProcessStatusKey(hostname string, pid int) string { + return fmt.Sprintf("%s%s:%d", psPrefix, hostname, pid) +} + // TaskMessage is the internal representation of a task with additional metadata fields. // Serialized data of this type gets written to redis. type TaskMessage struct { @@ -69,3 +76,13 @@ type TaskMessage struct { // ErrorMsg holds the error message from the last failure. ErrorMsg string } + +// ProcessStatus holds information about running background worker process. +type ProcessStatus struct { + Concurrency int + Queues map[string]uint + PID int + Host string + State string + Started time.Time +} diff --git a/internal/base/base_test.go b/internal/base/base_test.go index 52624c4..79f88b4 100644 --- a/internal/base/base_test.go +++ b/internal/base/base_test.go @@ -60,3 +60,21 @@ func TestFailureKey(t *testing.T) { } } } + +func TestProcessStatusKey(t *testing.T) { + tests := []struct { + hostname string + pid int + want string + }{ + {"localhost", 9876, "asynq:ps:localhost:9876"}, + {"127.0.0.1", 1234, "asynq:ps:127.0.0.1:1234"}, + } + + for _, tc := range tests { + got := ProcessStatusKey(tc.hostname, tc.pid) + if got != tc.want { + t.Errorf("ProcessStatusKey(%s, %d) = %s, want %s", tc.hostname, tc.pid, got, tc.want) + } + } +} diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index a41a957..6403da4 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -346,3 +346,29 @@ func (r *RDB) forwardSingle(src, dst string) error { return script.Run(r.client, []string{src, dst}, now).Err() } + +// WriteProcessStatus writes process information to redis with expiration +// set to the value ttl. +func (r *RDB) WriteProcessStatus(ps *base.ProcessStatus, ttl time.Duration) error { + bytes, err := json.Marshal(ps) + if err != nil { + return err + } + key := base.ProcessStatusKey(ps.Host, ps.PID) + return r.client.Set(key, string(bytes), ttl).Err() +} + +// ReadProcessStatus reads process information stored in redis. +func (r *RDB) ReadProcessStatus(host string, pid int) (*base.ProcessStatus, error) { + key := base.ProcessStatusKey(host, pid) + data, err := r.client.Get(key).Result() + if err != nil { + return nil, err + } + var ps base.ProcessStatus + err = json.Unmarshal([]byte(data), &ps) + if err != nil { + return nil, err + } + return &ps, nil +} diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index 9aee6e0..4cae1ef 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -738,3 +738,49 @@ func TestCheckAndEnqueue(t *testing.T) { } } } + +func TestReadWriteProcessStatus(t *testing.T) { + r := setup(t) + ps1 := &base.ProcessStatus{ + Concurrency: 10, + Queues: map[string]uint{"default": 2, "email": 5, "low": 1}, + PID: 98765, + Host: "localhost", + State: "running", + Started: time.Now(), + } + + tests := []struct { + ps *base.ProcessStatus + ttl time.Duration + }{ + {ps1, 5 * time.Second}, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + + err := r.WriteProcessStatus(tc.ps, tc.ttl) + if err != nil { + t.Errorf("r.WriteProcessStatus returned an error: %v", err) + continue + } + + got, err := r.ReadProcessStatus(tc.ps.Host, tc.ps.PID) + if err != nil { + t.Errorf("r.ReadProcessStatus returned an error: %v", err) + continue + } + + if diff := cmp.Diff(tc.ps, got); diff != "" { + t.Errorf("r.ReadProcessStatus(%q, %d) = %+v, want %+v; (-want,+got)\n%s", + tc.ps.Host, tc.ps.PID, got, tc.ps, diff) + } + + key := base.ProcessStatusKey(tc.ps.Host, tc.ps.PID) + gotTTL := r.client.TTL(key).Val() + if !cmp.Equal(tc.ttl, gotTTL, timeCmpOpt) { + t.Errorf("redis TTL %q returned %v, want %v", key, gotTTL, tc.ttl) + } + } +} diff --git a/syncer_test.go b/syncer_test.go index aa333c7..8eae321 100644 --- a/syncer_test.go +++ b/syncer_test.go @@ -87,6 +87,7 @@ func TestSyncerRetry(t *testing.T) { t.Errorf("%q has length %d; want %d", base.InProgressQueue, l, len(inProgress)) } + // FIXME: This assignment introduces data race and running the test with -race will fail. // simualate failover. rdbClient = rdb.NewRDB(goodClient)