mirror of
				https://github.com/hibiken/asynq.git
				synced 2025-10-25 10:56:12 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			924 lines
		
	
	
		
			26 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			924 lines
		
	
	
		
			26 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2020 Kentaro Hibino. All rights reserved.
 | |
| // Use of this source code is governed by a MIT license
 | |
| // that can be found in the LICENSE file.
 | |
| 
 | |
| package asynq
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"encoding/json"
 | |
| 	"fmt"
 | |
| 	"sort"
 | |
| 	"sync"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/google/go-cmp/cmp"
 | |
| 	"github.com/google/go-cmp/cmp/cmpopts"
 | |
| 	"github.com/hibiken/asynq/internal/base"
 | |
| 	"github.com/hibiken/asynq/internal/errors"
 | |
| 	"github.com/hibiken/asynq/internal/log"
 | |
| 	"github.com/hibiken/asynq/internal/rdb"
 | |
| 	h "github.com/hibiken/asynq/internal/testutil"
 | |
| 	"github.com/hibiken/asynq/internal/timeutil"
 | |
| )
 | |
| 
 | |
| var taskCmpOpts = []cmp.Option{
 | |
| 	sortTaskOpt,                               // sort the tasks
 | |
| 	cmp.AllowUnexported(Task{}),               // allow typename, payload fields to be compared
 | |
| 	cmpopts.IgnoreFields(Task{}, "opts", "w"), // ignore opts, w fields
 | |
| }
 | |
| 
 | |
| // fakeHeartbeater receives from starting and finished channels and do nothing.
 | |
| func fakeHeartbeater(starting <-chan *workerInfo, finished <-chan *base.TaskMessage, done <-chan struct{}) {
 | |
| 	for {
 | |
| 		select {
 | |
| 		case <-starting:
 | |
| 		case <-finished:
 | |
| 		case <-done:
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // fakeSyncer receives from sync channel and do nothing.
 | |
| func fakeSyncer(syncCh <-chan *syncRequest, done <-chan struct{}) {
 | |
| 	for {
 | |
| 		select {
 | |
| 		case <-syncCh:
 | |
| 		case <-done:
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Returns a processor instance configured for testing purpose.
 | |
| func newProcessorForTest(t *testing.T, r *rdb.RDB, h Handler) *processor {
 | |
| 	starting := make(chan *workerInfo)
 | |
| 	finished := make(chan *base.TaskMessage)
 | |
| 	syncCh := make(chan *syncRequest)
 | |
| 	done := make(chan struct{})
 | |
| 	t.Cleanup(func() { close(done) })
 | |
| 	go fakeHeartbeater(starting, finished, done)
 | |
| 	go fakeSyncer(syncCh, done)
 | |
| 	p := newProcessor(processorParams{
 | |
| 		logger:          testLogger,
 | |
| 		broker:          r,
 | |
| 		baseCtxFn:       context.Background,
 | |
| 		retryDelayFunc:  DefaultRetryDelayFunc,
 | |
| 		isFailureFunc:   defaultIsFailureFunc,
 | |
| 		syncCh:          syncCh,
 | |
| 		cancelations:    base.NewCancelations(),
 | |
| 		concurrency:     10,
 | |
| 		queues:          defaultQueueConfig,
 | |
| 		strictPriority:  false,
 | |
| 		errHandler:      nil,
 | |
| 		shutdownTimeout: defaultShutdownTimeout,
 | |
| 		starting:        starting,
 | |
| 		finished:        finished,
 | |
| 	})
 | |
| 	p.handler = h
 | |
| 	return p
 | |
| }
 | |
| 
 | |
| func TestProcessorSuccessWithSingleQueue(t *testing.T) {
 | |
| 	r := setup(t)
 | |
| 	defer r.Close()
 | |
| 	rdbClient := rdb.NewRDB(r)
 | |
| 
 | |
| 	m1 := h.NewTaskMessage("task1", nil)
 | |
| 	m2 := h.NewTaskMessage("task2", nil)
 | |
| 	m3 := h.NewTaskMessage("task3", nil)
 | |
| 	m4 := h.NewTaskMessage("task4", nil)
 | |
| 
 | |
| 	t1 := NewTask(m1.Type, m1.Payload)
 | |
| 	t2 := NewTask(m2.Type, m2.Payload)
 | |
| 	t3 := NewTask(m3.Type, m3.Payload)
 | |
| 	t4 := NewTask(m4.Type, m4.Payload)
 | |
| 
 | |
| 	tests := []struct {
 | |
| 		pending       []*base.TaskMessage // initial default queue state
 | |
| 		incoming      []*base.TaskMessage // tasks to be enqueued during run
 | |
| 		wantProcessed []*Task             // tasks to be processed at the end
 | |
| 	}{
 | |
| 		{
 | |
| 			pending:       []*base.TaskMessage{m1},
 | |
| 			incoming:      []*base.TaskMessage{m2, m3, m4},
 | |
| 			wantProcessed: []*Task{t1, t2, t3, t4},
 | |
| 		},
 | |
| 		{
 | |
| 			pending:       []*base.TaskMessage{},
 | |
| 			incoming:      []*base.TaskMessage{m1},
 | |
| 			wantProcessed: []*Task{t1},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range tests {
 | |
| 		h.FlushDB(t, r)                                             // clean up db before each test case.
 | |
| 		h.SeedPendingQueue(t, r, tc.pending, base.DefaultQueueName) // initialize default queue.
 | |
| 
 | |
| 		// instantiate a new processor
 | |
| 		var mu sync.Mutex
 | |
| 		var processed []*Task
 | |
| 		handler := func(ctx context.Context, task *Task) error {
 | |
| 			mu.Lock()
 | |
| 			defer mu.Unlock()
 | |
| 			processed = append(processed, task)
 | |
| 			return nil
 | |
| 		}
 | |
| 		p := newProcessorForTest(t, rdbClient, HandlerFunc(handler))
 | |
| 
 | |
| 		p.start(&sync.WaitGroup{})
 | |
| 		for _, msg := range tc.incoming {
 | |
| 			err := rdbClient.Enqueue(context.Background(), msg)
 | |
| 			if err != nil {
 | |
| 				p.shutdown()
 | |
| 				t.Fatal(err)
 | |
| 			}
 | |
| 		}
 | |
| 		time.Sleep(2 * time.Second) // wait for two second to allow all pending tasks to be processed.
 | |
| 		if l := r.LLen(context.Background(), base.ActiveKey(base.DefaultQueueName)).Val(); l != 0 {
 | |
| 			t.Errorf("%q has %d tasks, want 0", base.ActiveKey(base.DefaultQueueName), l)
 | |
| 		}
 | |
| 		p.shutdown()
 | |
| 
 | |
| 		mu.Lock()
 | |
| 		if diff := cmp.Diff(tc.wantProcessed, processed, taskCmpOpts...); diff != "" {
 | |
| 			t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff)
 | |
| 		}
 | |
| 		mu.Unlock()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestProcessorSuccessWithMultipleQueues(t *testing.T) {
 | |
| 	var (
 | |
| 		r         = setup(t)
 | |
| 		rdbClient = rdb.NewRDB(r)
 | |
| 
 | |
| 		m1 = h.NewTaskMessage("task1", nil)
 | |
| 		m2 = h.NewTaskMessage("task2", nil)
 | |
| 		m3 = h.NewTaskMessageWithQueue("task3", nil, "high")
 | |
| 		m4 = h.NewTaskMessageWithQueue("task4", nil, "low")
 | |
| 
 | |
| 		t1 = NewTask(m1.Type, m1.Payload)
 | |
| 		t2 = NewTask(m2.Type, m2.Payload)
 | |
| 		t3 = NewTask(m3.Type, m3.Payload)
 | |
| 		t4 = NewTask(m4.Type, m4.Payload)
 | |
| 	)
 | |
| 	defer r.Close()
 | |
| 
 | |
| 	tests := []struct {
 | |
| 		pending       map[string][]*base.TaskMessage
 | |
| 		queues        []string // list of queues to consume the tasks from
 | |
| 		wantProcessed []*Task  // tasks to be processed at the end
 | |
| 	}{
 | |
| 		{
 | |
| 			pending: map[string][]*base.TaskMessage{
 | |
| 				"default": {m1, m2},
 | |
| 				"high":    {m3},
 | |
| 				"low":     {m4},
 | |
| 			},
 | |
| 			queues:        []string{"default", "high", "low"},
 | |
| 			wantProcessed: []*Task{t1, t2, t3, t4},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range tests {
 | |
| 		// Set up test case.
 | |
| 		h.FlushDB(t, r)
 | |
| 		h.SeedAllPendingQueues(t, r, tc.pending)
 | |
| 
 | |
| 		// Instantiate a new processor.
 | |
| 		var mu sync.Mutex
 | |
| 		var processed []*Task
 | |
| 		handler := func(ctx context.Context, task *Task) error {
 | |
| 			mu.Lock()
 | |
| 			defer mu.Unlock()
 | |
| 			processed = append(processed, task)
 | |
| 			return nil
 | |
| 		}
 | |
| 		p := newProcessorForTest(t, rdbClient, HandlerFunc(handler))
 | |
| 		p.queueConfig = map[string]int{
 | |
| 			"default": 2,
 | |
| 			"high":    3,
 | |
| 			"low":     1,
 | |
| 		}
 | |
| 
 | |
| 		p.start(&sync.WaitGroup{})
 | |
| 		// Wait for two second to allow all pending tasks to be processed.
 | |
| 		time.Sleep(2 * time.Second)
 | |
| 		// Make sure no messages are stuck in active list.
 | |
| 		for _, qname := range tc.queues {
 | |
| 			if l := r.LLen(context.Background(), base.ActiveKey(qname)).Val(); l != 0 {
 | |
| 				t.Errorf("%q has %d tasks, want 0", base.ActiveKey(qname), l)
 | |
| 			}
 | |
| 		}
 | |
| 		p.shutdown()
 | |
| 
 | |
| 		mu.Lock()
 | |
| 		if diff := cmp.Diff(tc.wantProcessed, processed, taskCmpOpts...); diff != "" {
 | |
| 			t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff)
 | |
| 		}
 | |
| 		mu.Unlock()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // https://github.com/hibiken/asynq/issues/166
 | |
| func TestProcessTasksWithLargeNumberInPayload(t *testing.T) {
 | |
| 	r := setup(t)
 | |
| 	defer r.Close()
 | |
| 	rdbClient := rdb.NewRDB(r)
 | |
| 
 | |
| 	m1 := h.NewTaskMessage("large_number", h.JSON(map[string]interface{}{"data": 111111111111111111}))
 | |
| 	t1 := NewTask(m1.Type, m1.Payload)
 | |
| 
 | |
| 	tests := []struct {
 | |
| 		pending       []*base.TaskMessage // initial default queue state
 | |
| 		wantProcessed []*Task             // tasks to be processed at the end
 | |
| 	}{
 | |
| 		{
 | |
| 			pending:       []*base.TaskMessage{m1},
 | |
| 			wantProcessed: []*Task{t1},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range tests {
 | |
| 		h.FlushDB(t, r)                                             // clean up db before each test case.
 | |
| 		h.SeedPendingQueue(t, r, tc.pending, base.DefaultQueueName) // initialize default queue.
 | |
| 
 | |
| 		var mu sync.Mutex
 | |
| 		var processed []*Task
 | |
| 		handler := func(ctx context.Context, task *Task) error {
 | |
| 			mu.Lock()
 | |
| 			defer mu.Unlock()
 | |
| 			var payload map[string]int
 | |
| 			if err := json.Unmarshal(task.Payload(), &payload); err != nil {
 | |
| 				t.Errorf("coult not decode payload: %v", err)
 | |
| 			}
 | |
| 			if data, ok := payload["data"]; ok {
 | |
| 				t.Logf("data == %d", data)
 | |
| 			} else {
 | |
| 				t.Errorf("could not get data from payload")
 | |
| 			}
 | |
| 			processed = append(processed, task)
 | |
| 			return nil
 | |
| 		}
 | |
| 		p := newProcessorForTest(t, rdbClient, HandlerFunc(handler))
 | |
| 
 | |
| 		p.start(&sync.WaitGroup{})
 | |
| 		time.Sleep(2 * time.Second) // wait for two second to allow all pending tasks to be processed.
 | |
| 		if l := r.LLen(context.Background(), base.ActiveKey(base.DefaultQueueName)).Val(); l != 0 {
 | |
| 			t.Errorf("%q has %d tasks, want 0", base.ActiveKey(base.DefaultQueueName), l)
 | |
| 		}
 | |
| 		p.shutdown()
 | |
| 
 | |
| 		mu.Lock()
 | |
| 		if diff := cmp.Diff(tc.wantProcessed, processed, taskCmpOpts...); diff != "" {
 | |
| 			t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff)
 | |
| 		}
 | |
| 		mu.Unlock()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestProcessorRetry(t *testing.T) {
 | |
| 	r := setup(t)
 | |
| 	defer r.Close()
 | |
| 	rdbClient := rdb.NewRDB(r)
 | |
| 
 | |
| 	m1 := h.NewTaskMessage("send_email", nil)
 | |
| 	m1.Retried = m1.Retry // m1 has reached its max retry count
 | |
| 	m2 := h.NewTaskMessage("gen_thumbnail", nil)
 | |
| 	m3 := h.NewTaskMessage("reindex", nil)
 | |
| 	m4 := h.NewTaskMessage("sync", nil)
 | |
| 
 | |
| 	errMsg := "something went wrong"
 | |
| 	wrappedSkipRetry := fmt.Errorf("%s:%w", errMsg, SkipRetry)
 | |
| 
 | |
| 	tests := []struct {
 | |
| 		desc         string              // test description
 | |
| 		pending      []*base.TaskMessage // initial default queue state
 | |
| 		delay        time.Duration       // retry delay duration
 | |
| 		handler      Handler             // task handler
 | |
| 		wait         time.Duration       // wait duration between starting and stopping processor for this test case
 | |
| 		wantErrMsg   string              // error message the task should record
 | |
| 		wantRetry    []*base.TaskMessage // tasks in retry queue at the end
 | |
| 		wantArchived []*base.TaskMessage // tasks in archived queue at the end
 | |
| 		wantErrCount int                 // number of times error handler should be called
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:    "Should automatically retry errored tasks",
 | |
| 			pending: []*base.TaskMessage{m1, m2, m3, m4},
 | |
| 			delay:   time.Minute,
 | |
| 			handler: HandlerFunc(func(ctx context.Context, task *Task) error {
 | |
| 				return fmt.Errorf(errMsg)
 | |
| 			}),
 | |
| 			wait:         2 * time.Second,
 | |
| 			wantErrMsg:   errMsg,
 | |
| 			wantRetry:    []*base.TaskMessage{m2, m3, m4},
 | |
| 			wantArchived: []*base.TaskMessage{m1},
 | |
| 			wantErrCount: 4,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:    "Should skip retry errored tasks",
 | |
| 			pending: []*base.TaskMessage{m1, m2},
 | |
| 			delay:   time.Minute,
 | |
| 			handler: HandlerFunc(func(ctx context.Context, task *Task) error {
 | |
| 				return SkipRetry // return SkipRetry without wrapping
 | |
| 			}),
 | |
| 			wait:         2 * time.Second,
 | |
| 			wantErrMsg:   SkipRetry.Error(),
 | |
| 			wantRetry:    []*base.TaskMessage{},
 | |
| 			wantArchived: []*base.TaskMessage{m1, m2},
 | |
| 			wantErrCount: 2, // ErrorHandler should still be called with SkipRetry error
 | |
| 		},
 | |
| 		{
 | |
| 			desc:    "Should skip retry errored tasks (with error wrapping)",
 | |
| 			pending: []*base.TaskMessage{m1, m2},
 | |
| 			delay:   time.Minute,
 | |
| 			handler: HandlerFunc(func(ctx context.Context, task *Task) error {
 | |
| 				return wrappedSkipRetry
 | |
| 			}),
 | |
| 			wait:         2 * time.Second,
 | |
| 			wantErrMsg:   wrappedSkipRetry.Error(),
 | |
| 			wantRetry:    []*base.TaskMessage{},
 | |
| 			wantArchived: []*base.TaskMessage{m1, m2},
 | |
| 			wantErrCount: 2, // ErrorHandler should still be called with SkipRetry error
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range tests {
 | |
| 		h.FlushDB(t, r)                                             // clean up db before each test case.
 | |
| 		h.SeedPendingQueue(t, r, tc.pending, base.DefaultQueueName) // initialize default queue.
 | |
| 
 | |
| 		// instantiate a new processor
 | |
| 		delayFunc := func(n int, e error, t *Task) time.Duration {
 | |
| 			return tc.delay
 | |
| 		}
 | |
| 		var (
 | |
| 			mu sync.Mutex // guards n
 | |
| 			n  int        // number of times error handler is called
 | |
| 		)
 | |
| 		errHandler := func(ctx context.Context, t *Task, err error) {
 | |
| 			mu.Lock()
 | |
| 			defer mu.Unlock()
 | |
| 			n++
 | |
| 		}
 | |
| 		p := newProcessorForTest(t, rdbClient, tc.handler)
 | |
| 		p.errHandler = ErrorHandlerFunc(errHandler)
 | |
| 		p.retryDelayFunc = delayFunc
 | |
| 
 | |
| 		p.start(&sync.WaitGroup{})
 | |
| 		runTime := time.Now() // time when processor is running
 | |
| 		time.Sleep(tc.wait)   // FIXME: This makes test flaky.
 | |
| 		p.shutdown()
 | |
| 
 | |
| 		cmpOpt := h.EquateInt64Approx(int64(tc.wait.Seconds())) // allow up to a wait-second difference in zset score
 | |
| 		gotRetry := h.GetRetryEntries(t, r, base.DefaultQueueName)
 | |
| 		var wantRetry []base.Z // Note: construct wantRetry here since `LastFailedAt` and ZSCORE is relative to each test run.
 | |
| 		for _, msg := range tc.wantRetry {
 | |
| 			wantRetry = append(wantRetry,
 | |
| 				base.Z{
 | |
| 					Message: h.TaskMessageAfterRetry(*msg, tc.wantErrMsg, runTime),
 | |
| 					Score:   runTime.Add(tc.delay).Unix(),
 | |
| 				})
 | |
| 		}
 | |
| 		if diff := cmp.Diff(wantRetry, gotRetry, h.SortZSetEntryOpt, cmpOpt); diff != "" {
 | |
| 			t.Errorf("%s: mismatch found in %q after running processor; (-want, +got)\n%s", tc.desc, base.RetryKey(base.DefaultQueueName), diff)
 | |
| 		}
 | |
| 
 | |
| 		gotArchived := h.GetArchivedEntries(t, r, base.DefaultQueueName)
 | |
| 		var wantArchived []base.Z // Note: construct wantArchived here since `LastFailedAt` and ZSCORE is relative to each test run.
 | |
| 		for _, msg := range tc.wantArchived {
 | |
| 			wantArchived = append(wantArchived,
 | |
| 				base.Z{
 | |
| 					Message: h.TaskMessageWithError(*msg, tc.wantErrMsg, runTime),
 | |
| 					Score:   runTime.Unix(),
 | |
| 				})
 | |
| 		}
 | |
| 		if diff := cmp.Diff(wantArchived, gotArchived, h.SortZSetEntryOpt, cmpOpt); diff != "" {
 | |
| 			t.Errorf("%s: mismatch found in %q after running processor; (-want, +got)\n%s", tc.desc, base.ArchivedKey(base.DefaultQueueName), diff)
 | |
| 		}
 | |
| 
 | |
| 		if l := r.LLen(context.Background(), base.ActiveKey(base.DefaultQueueName)).Val(); l != 0 {
 | |
| 			t.Errorf("%s: %q has %d tasks, want 0", base.ActiveKey(base.DefaultQueueName), tc.desc, l)
 | |
| 		}
 | |
| 
 | |
| 		if n != tc.wantErrCount {
 | |
| 			t.Errorf("error handler was called %d times, want %d", n, tc.wantErrCount)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestProcessorMarkAsComplete(t *testing.T) {
 | |
| 	r := setup(t)
 | |
| 	defer r.Close()
 | |
| 	rdbClient := rdb.NewRDB(r)
 | |
| 
 | |
| 	msg1 := h.NewTaskMessage("one", nil)
 | |
| 	msg2 := h.NewTaskMessage("two", nil)
 | |
| 	msg3 := h.NewTaskMessageWithQueue("three", nil, "custom")
 | |
| 	msg1.Retention = 3600
 | |
| 	msg3.Retention = 7200
 | |
| 
 | |
| 	handler := func(ctx context.Context, task *Task) error { return nil }
 | |
| 
 | |
| 	tests := []struct {
 | |
| 		pending       map[string][]*base.TaskMessage
 | |
| 		completed     map[string][]base.Z
 | |
| 		queueCfg      map[string]int
 | |
| 		wantPending   map[string][]*base.TaskMessage
 | |
| 		wantCompleted func(completedAt time.Time) map[string][]base.Z
 | |
| 	}{
 | |
| 		{
 | |
| 			pending: map[string][]*base.TaskMessage{
 | |
| 				"default": {msg1, msg2},
 | |
| 				"custom":  {msg3},
 | |
| 			},
 | |
| 			completed: map[string][]base.Z{
 | |
| 				"default": {},
 | |
| 				"custom":  {},
 | |
| 			},
 | |
| 			queueCfg: map[string]int{
 | |
| 				"default": 1,
 | |
| 				"custom":  1,
 | |
| 			},
 | |
| 			wantPending: map[string][]*base.TaskMessage{
 | |
| 				"default": {},
 | |
| 				"custom":  {},
 | |
| 			},
 | |
| 			wantCompleted: func(completedAt time.Time) map[string][]base.Z {
 | |
| 				return map[string][]base.Z{
 | |
| 					"default": {{Message: h.TaskMessageWithCompletedAt(*msg1, completedAt), Score: completedAt.Unix() + msg1.Retention}},
 | |
| 					"custom":  {{Message: h.TaskMessageWithCompletedAt(*msg3, completedAt), Score: completedAt.Unix() + msg3.Retention}},
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range tests {
 | |
| 		h.FlushDB(t, r)
 | |
| 		h.SeedAllPendingQueues(t, r, tc.pending)
 | |
| 		h.SeedAllCompletedQueues(t, r, tc.completed)
 | |
| 
 | |
| 		p := newProcessorForTest(t, rdbClient, HandlerFunc(handler))
 | |
| 		p.queueConfig = tc.queueCfg
 | |
| 
 | |
| 		p.start(&sync.WaitGroup{})
 | |
| 		runTime := time.Now() // time when processor is running
 | |
| 		time.Sleep(2 * time.Second)
 | |
| 		p.shutdown()
 | |
| 
 | |
| 		for qname, want := range tc.wantPending {
 | |
| 			gotPending := h.GetPendingMessages(t, r, qname)
 | |
| 			if diff := cmp.Diff(want, gotPending, cmpopts.EquateEmpty()); diff != "" {
 | |
| 				t.Errorf("diff found in %q pending set; want=%v, got=%v\n%s", qname, want, gotPending, diff)
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		for qname, want := range tc.wantCompleted(runTime) {
 | |
| 			gotCompleted := h.GetCompletedEntries(t, r, qname)
 | |
| 			if diff := cmp.Diff(want, gotCompleted, cmpopts.EquateEmpty()); diff != "" {
 | |
| 				t.Errorf("diff found in %q completed set; want=%v, got=%v\n%s", qname, want, gotCompleted, diff)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Test a scenario where the worker server cannot communicate with redis due to a network failure
 | |
| // and the lease expires
 | |
| func TestProcessorWithExpiredLease(t *testing.T) {
 | |
| 	r := setup(t)
 | |
| 	defer r.Close()
 | |
| 	rdbClient := rdb.NewRDB(r)
 | |
| 
 | |
| 	m1 := h.NewTaskMessage("task1", nil)
 | |
| 
 | |
| 	tests := []struct {
 | |
| 		pending      []*base.TaskMessage
 | |
| 		handler      Handler
 | |
| 		wantErrCount int
 | |
| 	}{
 | |
| 		{
 | |
| 			pending: []*base.TaskMessage{m1},
 | |
| 			handler: HandlerFunc(func(ctx context.Context, task *Task) error {
 | |
| 				// make sure the task processing time exceeds lease duration
 | |
| 				// to test expired lease.
 | |
| 				time.Sleep(rdb.LeaseDuration + 10*time.Second)
 | |
| 				return nil
 | |
| 			}),
 | |
| 			wantErrCount: 1, // ErrorHandler should still be called with ErrLeaseExpired
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range tests {
 | |
| 		h.FlushDB(t, r)
 | |
| 		h.SeedPendingQueue(t, r, tc.pending, base.DefaultQueueName)
 | |
| 
 | |
| 		starting := make(chan *workerInfo)
 | |
| 		finished := make(chan *base.TaskMessage)
 | |
| 		syncCh := make(chan *syncRequest)
 | |
| 		done := make(chan struct{})
 | |
| 		t.Cleanup(func() { close(done) })
 | |
| 		// fake heartbeater which notifies lease expiration
 | |
| 		go func() {
 | |
| 			for {
 | |
| 				select {
 | |
| 				case w := <-starting:
 | |
| 					// simulate expiration by resetting to some time in the past
 | |
| 					w.lease.Reset(time.Now().Add(-5 * time.Second))
 | |
| 					if !w.lease.NotifyExpiration() {
 | |
| 						panic("Failed to notifiy lease expiration")
 | |
| 					}
 | |
| 				case <-finished:
 | |
| 					// do nothing
 | |
| 				case <-done:
 | |
| 					return
 | |
| 				}
 | |
| 			}
 | |
| 		}()
 | |
| 		go fakeSyncer(syncCh, done)
 | |
| 		p := newProcessor(processorParams{
 | |
| 			logger:          testLogger,
 | |
| 			broker:          rdbClient,
 | |
| 			baseCtxFn:       context.Background,
 | |
| 			retryDelayFunc:  DefaultRetryDelayFunc,
 | |
| 			isFailureFunc:   defaultIsFailureFunc,
 | |
| 			syncCh:          syncCh,
 | |
| 			cancelations:    base.NewCancelations(),
 | |
| 			concurrency:     10,
 | |
| 			queues:          defaultQueueConfig,
 | |
| 			strictPriority:  false,
 | |
| 			errHandler:      nil,
 | |
| 			shutdownTimeout: defaultShutdownTimeout,
 | |
| 			starting:        starting,
 | |
| 			finished:        finished,
 | |
| 		})
 | |
| 		p.handler = tc.handler
 | |
| 		var (
 | |
| 			mu   sync.Mutex // guards n and errs
 | |
| 			n    int        // number of times error handler is called
 | |
| 			errs []error    // error passed to error handler
 | |
| 		)
 | |
| 		p.errHandler = ErrorHandlerFunc(func(ctx context.Context, t *Task, err error) {
 | |
| 			mu.Lock()
 | |
| 			defer mu.Unlock()
 | |
| 			n++
 | |
| 			errs = append(errs, err)
 | |
| 		})
 | |
| 
 | |
| 		p.start(&sync.WaitGroup{})
 | |
| 		time.Sleep(4 * time.Second)
 | |
| 		p.shutdown()
 | |
| 
 | |
| 		if n != tc.wantErrCount {
 | |
| 			t.Errorf("Unexpected number of error count: got %d, want %d", n, tc.wantErrCount)
 | |
| 			continue
 | |
| 		}
 | |
| 		for i := 0; i < tc.wantErrCount; i++ {
 | |
| 			if !errors.Is(errs[i], ErrLeaseExpired) {
 | |
| 				t.Errorf("Unexpected error was passed to ErrorHandler: got %v want %v", errs[i], ErrLeaseExpired)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestProcessorQueues(t *testing.T) {
 | |
| 	sortOpt := cmp.Transformer("SortStrings", func(in []string) []string {
 | |
| 		out := append([]string(nil), in...) // Copy input to avoid mutating it
 | |
| 		sort.Strings(out)
 | |
| 		return out
 | |
| 	})
 | |
| 
 | |
| 	tests := []struct {
 | |
| 		queueCfg map[string]int
 | |
| 		want     []string
 | |
| 	}{
 | |
| 		{
 | |
| 			queueCfg: map[string]int{
 | |
| 				"high":    6,
 | |
| 				"default": 3,
 | |
| 				"low":     1,
 | |
| 			},
 | |
| 			want: []string{"high", "default", "low"},
 | |
| 		},
 | |
| 		{
 | |
| 			queueCfg: map[string]int{
 | |
| 				"default": 1,
 | |
| 			},
 | |
| 			want: []string{"default"},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range tests {
 | |
| 		// Note: rdb and handler not needed for this test.
 | |
| 		p := newProcessorForTest(t, nil, nil)
 | |
| 		p.queueConfig = tc.queueCfg
 | |
| 
 | |
| 		got := p.queues()
 | |
| 		if diff := cmp.Diff(tc.want, got, sortOpt); diff != "" {
 | |
| 			t.Errorf("with queue config: %v\n(*processor).queues() = %v, want %v\n(-want,+got):\n%s",
 | |
| 				tc.queueCfg, got, tc.want, diff)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestProcessorWithStrictPriority(t *testing.T) {
 | |
| 	var (
 | |
| 		r = setup(t)
 | |
| 
 | |
| 		rdbClient = rdb.NewRDB(r)
 | |
| 
 | |
| 		m1 = h.NewTaskMessageWithQueue("task1", nil, "critical")
 | |
| 		m2 = h.NewTaskMessageWithQueue("task2", nil, "critical")
 | |
| 		m3 = h.NewTaskMessageWithQueue("task3", nil, "critical")
 | |
| 		m4 = h.NewTaskMessageWithQueue("task4", nil, base.DefaultQueueName)
 | |
| 		m5 = h.NewTaskMessageWithQueue("task5", nil, base.DefaultQueueName)
 | |
| 		m6 = h.NewTaskMessageWithQueue("task6", nil, "low")
 | |
| 		m7 = h.NewTaskMessageWithQueue("task7", nil, "low")
 | |
| 
 | |
| 		t1 = NewTask(m1.Type, m1.Payload)
 | |
| 		t2 = NewTask(m2.Type, m2.Payload)
 | |
| 		t3 = NewTask(m3.Type, m3.Payload)
 | |
| 		t4 = NewTask(m4.Type, m4.Payload)
 | |
| 		t5 = NewTask(m5.Type, m5.Payload)
 | |
| 		t6 = NewTask(m6.Type, m6.Payload)
 | |
| 		t7 = NewTask(m7.Type, m7.Payload)
 | |
| 	)
 | |
| 	defer r.Close()
 | |
| 
 | |
| 	tests := []struct {
 | |
| 		pending       map[string][]*base.TaskMessage // initial queues state
 | |
| 		queues        []string                       // list of queues to consume tasks from
 | |
| 		wait          time.Duration                  // wait duration between starting and stopping processor for this test case
 | |
| 		wantProcessed []*Task                        // tasks to be processed at the end
 | |
| 	}{
 | |
| 		{
 | |
| 			pending: map[string][]*base.TaskMessage{
 | |
| 				base.DefaultQueueName: {m4, m5},
 | |
| 				"critical":            {m1, m2, m3},
 | |
| 				"low":                 {m6, m7},
 | |
| 			},
 | |
| 			queues:        []string{base.DefaultQueueName, "critical", "low"},
 | |
| 			wait:          time.Second,
 | |
| 			wantProcessed: []*Task{t1, t2, t3, t4, t5, t6, t7},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range tests {
 | |
| 		h.FlushDB(t, r) // clean up db before each test case.
 | |
| 		for qname, msgs := range tc.pending {
 | |
| 			h.SeedPendingQueue(t, r, msgs, qname)
 | |
| 		}
 | |
| 
 | |
| 		// instantiate a new processor
 | |
| 		var mu sync.Mutex
 | |
| 		var processed []*Task
 | |
| 		handler := func(ctx context.Context, task *Task) error {
 | |
| 			mu.Lock()
 | |
| 			defer mu.Unlock()
 | |
| 			processed = append(processed, task)
 | |
| 			return nil
 | |
| 		}
 | |
| 		queueCfg := map[string]int{
 | |
| 			base.DefaultQueueName: 2,
 | |
| 			"critical":            3,
 | |
| 			"low":                 1,
 | |
| 		}
 | |
| 		starting := make(chan *workerInfo)
 | |
| 		finished := make(chan *base.TaskMessage)
 | |
| 		syncCh := make(chan *syncRequest)
 | |
| 		done := make(chan struct{})
 | |
| 		defer func() { close(done) }()
 | |
| 		go fakeHeartbeater(starting, finished, done)
 | |
| 		go fakeSyncer(syncCh, done)
 | |
| 		p := newProcessor(processorParams{
 | |
| 			logger:          testLogger,
 | |
| 			broker:          rdbClient,
 | |
| 			baseCtxFn:       context.Background,
 | |
| 			retryDelayFunc:  DefaultRetryDelayFunc,
 | |
| 			isFailureFunc:   defaultIsFailureFunc,
 | |
| 			syncCh:          syncCh,
 | |
| 			cancelations:    base.NewCancelations(),
 | |
| 			concurrency:     1, // Set concurrency to 1 to make sure tasks are processed one at a time.
 | |
| 			queues:          queueCfg,
 | |
| 			strictPriority:  true,
 | |
| 			errHandler:      nil,
 | |
| 			shutdownTimeout: defaultShutdownTimeout,
 | |
| 			starting:        starting,
 | |
| 			finished:        finished,
 | |
| 		})
 | |
| 		p.handler = HandlerFunc(handler)
 | |
| 
 | |
| 		p.start(&sync.WaitGroup{})
 | |
| 		time.Sleep(tc.wait)
 | |
| 		// Make sure no tasks are stuck in active list.
 | |
| 		for _, qname := range tc.queues {
 | |
| 			if l := r.LLen(context.Background(), base.ActiveKey(qname)).Val(); l != 0 {
 | |
| 				t.Errorf("%q has %d tasks, want 0", base.ActiveKey(qname), l)
 | |
| 			}
 | |
| 		}
 | |
| 		p.shutdown()
 | |
| 
 | |
| 		if diff := cmp.Diff(tc.wantProcessed, processed, taskCmpOpts...); diff != "" {
 | |
| 			t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff)
 | |
| 		}
 | |
| 
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestProcessorPerform(t *testing.T) {
 | |
| 	tests := []struct {
 | |
| 		desc    string
 | |
| 		handler HandlerFunc
 | |
| 		task    *Task
 | |
| 		wantErr bool
 | |
| 	}{
 | |
| 		{
 | |
| 			desc: "handler returns nil",
 | |
| 			handler: func(ctx context.Context, t *Task) error {
 | |
| 				return nil
 | |
| 			},
 | |
| 			task:    NewTask("gen_thumbnail", h.JSON(map[string]interface{}{"src": "some/img/path"})),
 | |
| 			wantErr: false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "handler returns error",
 | |
| 			handler: func(ctx context.Context, t *Task) error {
 | |
| 				return fmt.Errorf("something went wrong")
 | |
| 			},
 | |
| 			task:    NewTask("gen_thumbnail", h.JSON(map[string]interface{}{"src": "some/img/path"})),
 | |
| 			wantErr: true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "handler panics",
 | |
| 			handler: func(ctx context.Context, t *Task) error {
 | |
| 				panic("something went terribly wrong")
 | |
| 			},
 | |
| 			task:    NewTask("gen_thumbnail", h.JSON(map[string]interface{}{"src": "some/img/path"})),
 | |
| 			wantErr: true,
 | |
| 		},
 | |
| 	}
 | |
| 	// Note: We don't need to fully initialized the processor since we are only testing
 | |
| 	// perform method.
 | |
| 	p := newProcessorForTest(t, nil, nil)
 | |
| 
 | |
| 	for _, tc := range tests {
 | |
| 		p.handler = tc.handler
 | |
| 		got := p.perform(context.Background(), tc.task)
 | |
| 		if !tc.wantErr && got != nil {
 | |
| 			t.Errorf("%s: perform() = %v, want nil", tc.desc, got)
 | |
| 			continue
 | |
| 		}
 | |
| 		if tc.wantErr && got == nil {
 | |
| 			t.Errorf("%s: perform() = nil, want non-nil error", tc.desc)
 | |
| 			continue
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestGCD(t *testing.T) {
 | |
| 	tests := []struct {
 | |
| 		input []int
 | |
| 		want  int
 | |
| 	}{
 | |
| 		{[]int{6, 2, 12}, 2},
 | |
| 		{[]int{3, 3, 3}, 3},
 | |
| 		{[]int{6, 3, 1}, 1},
 | |
| 		{[]int{1}, 1},
 | |
| 		{[]int{1, 0, 2}, 1},
 | |
| 		{[]int{8, 0, 4}, 4},
 | |
| 		{[]int{9, 12, 18, 30}, 3},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range tests {
 | |
| 		got := gcd(tc.input...)
 | |
| 		if got != tc.want {
 | |
| 			t.Errorf("gcd(%v) = %d, want %d", tc.input, got, tc.want)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestNormalizeQueues(t *testing.T) {
 | |
| 	tests := []struct {
 | |
| 		input map[string]int
 | |
| 		want  map[string]int
 | |
| 	}{
 | |
| 		{
 | |
| 			input: map[string]int{
 | |
| 				"high":    100,
 | |
| 				"default": 20,
 | |
| 				"low":     5,
 | |
| 			},
 | |
| 			want: map[string]int{
 | |
| 				"high":    20,
 | |
| 				"default": 4,
 | |
| 				"low":     1,
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			input: map[string]int{
 | |
| 				"default": 10,
 | |
| 			},
 | |
| 			want: map[string]int{
 | |
| 				"default": 1,
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			input: map[string]int{
 | |
| 				"critical": 5,
 | |
| 				"default":  1,
 | |
| 			},
 | |
| 			want: map[string]int{
 | |
| 				"critical": 5,
 | |
| 				"default":  1,
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			input: map[string]int{
 | |
| 				"critical": 6,
 | |
| 				"default":  3,
 | |
| 				"low":      0,
 | |
| 			},
 | |
| 			want: map[string]int{
 | |
| 				"critical": 2,
 | |
| 				"default":  1,
 | |
| 				"low":      0,
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range tests {
 | |
| 		got := normalizeQueues(tc.input)
 | |
| 		if diff := cmp.Diff(tc.want, got); diff != "" {
 | |
| 			t.Errorf("normalizeQueues(%v) = %v, want %v; (-want, +got):\n%s",
 | |
| 				tc.input, got, tc.want, diff)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestProcessorComputeDeadline(t *testing.T) {
 | |
| 	now := time.Now()
 | |
| 	p := processor{
 | |
| 		logger: log.NewLogger(nil),
 | |
| 		clock:  timeutil.NewSimulatedClock(now),
 | |
| 	}
 | |
| 
 | |
| 	tests := []struct {
 | |
| 		desc string
 | |
| 		msg  *base.TaskMessage
 | |
| 		want time.Time
 | |
| 	}{
 | |
| 		{
 | |
| 			desc: "message with only timeout specified",
 | |
| 			msg: &base.TaskMessage{
 | |
| 				Timeout: int64((30 * time.Minute).Seconds()),
 | |
| 			},
 | |
| 			want: now.Add(30 * time.Minute),
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "message with only deadline specified",
 | |
| 			msg: &base.TaskMessage{
 | |
| 				Deadline: now.Add(24 * time.Hour).Unix(),
 | |
| 			},
 | |
| 			want: now.Add(24 * time.Hour),
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "message with both timeout and deadline set (now+timeout < deadline)",
 | |
| 			msg: &base.TaskMessage{
 | |
| 				Deadline: now.Add(24 * time.Hour).Unix(),
 | |
| 				Timeout:  int64((30 * time.Minute).Seconds()),
 | |
| 			},
 | |
| 			want: now.Add(30 * time.Minute),
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "message with both timeout and deadline set (now+timeout > deadline)",
 | |
| 			msg: &base.TaskMessage{
 | |
| 				Deadline: now.Add(10 * time.Minute).Unix(),
 | |
| 				Timeout:  int64((30 * time.Minute).Seconds()),
 | |
| 			},
 | |
| 			want: now.Add(10 * time.Minute),
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "message with both timeout and deadline set (now+timeout == deadline)",
 | |
| 			msg: &base.TaskMessage{
 | |
| 				Deadline: now.Add(30 * time.Minute).Unix(),
 | |
| 				Timeout:  int64((30 * time.Minute).Seconds()),
 | |
| 			},
 | |
| 			want: now.Add(30 * time.Minute),
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "message without timeout and deadline",
 | |
| 			msg:  &base.TaskMessage{},
 | |
| 			want: now.Add(defaultTimeout),
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range tests {
 | |
| 		got := p.computeDeadline(tc.msg)
 | |
| 		// Compare the Unix epoch with seconds granularity
 | |
| 		if got.Unix() != tc.want.Unix() {
 | |
| 			t.Errorf("%s: got=%v, want=%v", tc.desc, got.Unix(), tc.want.Unix())
 | |
| 		}
 | |
| 	}
 | |
| }
 |