diff --git a/client.go b/client.go index 71bc154..b531bb8 100644 --- a/client.go +++ b/client.go @@ -341,6 +341,9 @@ func (c *Client) Enqueue(task *Task, opts ...Option) (*TaskInfo, error) { // // The first argument context applies to the enqueue operation. To specify task timeout and deadline, use Timeout and Deadline option instead. func (c *Client) EnqueueContext(ctx context.Context, task *Task, opts ...Option) (*TaskInfo, error) { + if task == nil { + return nil, fmt.Errorf("task cannot be nil") + } if strings.TrimSpace(task.Type()) == "" { return nil, fmt.Errorf("task typename cannot be empty") } diff --git a/client_test.go b/client_test.go index d305d76..2d4330e 100644 --- a/client_test.go +++ b/client_test.go @@ -857,6 +857,11 @@ func TestClientEnqueueError(t *testing.T) { task *Task opts []Option }{ + { + desc: "With nil task", + task: nil, + opts: []Option{}, + }, { desc: "With empty queue name", task: task, diff --git a/internal/testbroker/testbroker.go b/internal/testbroker/testbroker.go index ce35866..5d228d5 100644 --- a/internal/testbroker/testbroker.go +++ b/internal/testbroker/testbroker.go @@ -234,3 +234,57 @@ func (tb *TestBroker) Close() error { } return tb.real.Close() } + +func (tb *TestBroker) AddToGroup(ctx context.Context, msg *base.TaskMessage, gname string) error { + tb.mu.Lock() + defer tb.mu.Unlock() + if tb.sleeping { + return errRedisDown + } + return tb.real.AddToGroup(ctx, msg, gname) +} + +func (tb *TestBroker) AddToGroupUnique(ctx context.Context, msg *base.TaskMessage, gname string, ttl time.Duration) error { + tb.mu.Lock() + defer tb.mu.Unlock() + if tb.sleeping { + return errRedisDown + } + return tb.real.AddToGroupUnique(ctx, msg, gname, ttl) +} + +func (tb *TestBroker) ListGroups(qname string) ([]string, error) { + tb.mu.Lock() + defer tb.mu.Unlock() + if tb.sleeping { + return nil, errRedisDown + } + return tb.real.ListGroups(qname) +} + +func (tb *TestBroker) AggregationCheck(qname, gname string, t time.Time, gracePeriod, maxDelay time.Duration, maxSize int) (aggregationSetID string, err error) { + tb.mu.Lock() + defer tb.mu.Unlock() + if tb.sleeping { + return "", errRedisDown + } + return tb.real.AggregationCheck(qname, gname, t, gracePeriod, maxDelay, maxSize) +} + +func (tb *TestBroker) ReadAggregationSet(qname, gname, aggregationSetID string) ([]*base.TaskMessage, time.Time, error) { + tb.mu.Lock() + defer tb.mu.Unlock() + if tb.sleeping { + return nil, time.Time{}, errRedisDown + } + return tb.real.ReadAggregationSet(qname, gname, aggregationSetID) +} + +func (tb *TestBroker) DeleteAggregationSet(ctx context.Context, qname, gname, aggregationSetID string) error { + tb.mu.Lock() + defer tb.mu.Unlock() + if tb.sleeping { + return errRedisDown + } + return tb.real.DeleteAggregationSet(ctx, qname, gname, aggregationSetID) +}