From f7f6c93271cb23bf1a54f5298fda7f766cd84969 Mon Sep 17 00:00:00 2001 From: ajatprabha Date: Sun, 31 Oct 2021 21:53:38 +0530 Subject: [PATCH] separate out error test cases --- x/rate/semaphore_test.go | 214 ++++++++++++++++++++++++++------------- 1 file changed, 145 insertions(+), 69 deletions(-) diff --git a/x/rate/semaphore_test.go b/x/rate/semaphore_test.go index e9ee484..6273687 100644 --- a/x/rate/semaphore_test.go +++ b/x/rate/semaphore_test.go @@ -83,10 +83,9 @@ func TestNewSemaphore_Acquire(t *testing.T) { taskIDs []uuid.UUID ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc) want []bool - wantErr string }{ { - desc: "Should acquire lock when current lock count is less than maxTokens", + desc: "Should acquire token when current token count is less than maxTokens", name: "task-1", maxConcurrency: 3, taskIDs: []uuid.UUID{uuid.New(), uuid.New()}, @@ -99,7 +98,7 @@ func TestNewSemaphore_Acquire(t *testing.T) { want: []bool{true, true}, }, { - desc: "Should fail acquiring lock when current lock count is equal to maxTokens", + desc: "Should fail acquiring token when current token count is equal to maxTokens", name: "task-2", maxConcurrency: 3, taskIDs: []uuid.UUID{uuid.New(), uuid.New(), uuid.New(), uuid.New()}, @@ -111,28 +110,6 @@ func TestNewSemaphore_Acquire(t *testing.T) { }, want: []bool{true, true, true, false}, }, - { - desc: "Should return error if context has no deadline", - name: "task-3", - maxConcurrency: 1, - taskIDs: []uuid.UUID{uuid.New(), uuid.New()}, - ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) { - return context.Background(), func() {} - }, - want: []bool{false, false}, - wantErr: "provided context must have a deadline", - }, - { - desc: "Should return error when context is missing taskID", - name: "task-4", - maxConcurrency: 1, - taskIDs: []uuid.UUID{uuid.New()}, - ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) { - return context.WithTimeout(context.Background(), time.Second) - }, - want: []bool{false}, - wantErr: "provided context is missing task ID value", - }, } for _, tt := range tests { @@ -152,12 +129,13 @@ func TestNewSemaphore_Acquire(t *testing.T) { ctx, cancel := tt.ctxFunc(tt.taskIDs[i]) got, err := sema.Acquire(ctx) + if err != nil { + t.Errorf("%s;\nSemaphore.Acquire() got error %v", tt.desc, err) + } + if got != tt.want[i] { t.Errorf("%s;\nSemaphore.Acquire(ctx) returned %v, want %v", tt.desc, got, tt.want[i]) } - if (tt.wantErr == "" && err != nil) || (tt.wantErr != "" && (err == nil || err.Error() != tt.wantErr)) { - t.Errorf("%s;\nSemaphore.Acquire() got error %v want error %v", tt.desc, err, tt.wantErr) - } cancel() } @@ -165,19 +143,79 @@ func TestNewSemaphore_Acquire(t *testing.T) { } } -func TestNewSemaphore_Acquire_StaleLock(t *testing.T) { +func TestNewSemaphore_Acquire_Error(t *testing.T) { + tests := []struct { + desc string + name string + maxConcurrency int + taskIDs []uuid.UUID + ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc) + errStr string + }{ + { + desc: "Should return error if context has no deadline", + name: "task-3", + maxConcurrency: 1, + taskIDs: []uuid.UUID{uuid.New(), uuid.New()}, + ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) { + return context.Background(), func() {} + }, + errStr: "provided context must have a deadline", + }, + { + desc: "Should return error when context is missing taskID", + name: "task-4", + maxConcurrency: 1, + taskIDs: []uuid.UUID{uuid.New()}, + ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Second) + }, + errStr: "provided context is missing task ID value", + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + opt := getRedisConnOpt(t) + rc := opt.MakeRedisClient().(redis.UniversalClient) + defer rc.Close() + + if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil { + t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err) + } + + sema := NewSemaphore(opt, tt.name, tt.maxConcurrency) + defer sema.Close() + + for i := 0; i < len(tt.taskIDs); i++ { + ctx, cancel := tt.ctxFunc(tt.taskIDs[i]) + + _, err := sema.Acquire(ctx) + if err == nil || err.Error() != tt.errStr { + t.Errorf("%s;\nSemaphore.Acquire() got error %v want error %v", tt.desc, err, tt.errStr) + } + + cancel() + } + }) + } +} + +func TestNewSemaphore_Acquire_StaleToken(t *testing.T) { opt := getRedisConnOpt(t) rc := opt.MakeRedisClient().(redis.UniversalClient) defer rc.Close() taskID := uuid.New() - rc.ZAdd(context.Background(), semaphoreKey("stale-lock"), &redis.Z{ + // adding a set member to mimic the case where token is acquired but the goroutine crashed, + // in which case, the token will not be explicitly removed and should be present already + rc.ZAdd(context.Background(), semaphoreKey("stale-token"), &redis.Z{ Score: float64(time.Now().Add(-10 * time.Second).Unix()), Member: taskID.String(), }) - sema := NewSemaphore(opt, "stale-lock", 1) + sema := NewSemaphore(opt, "stale-token", 1) defer sema.Close() ctx, cancel := asynqcontext.New(&base.TaskMessage{ @@ -188,27 +226,24 @@ func TestNewSemaphore_Acquire_StaleLock(t *testing.T) { got, err := sema.Acquire(ctx) if err != nil { - t.Errorf("Acquire_StaleLock;\nSemaphore.Acquire() got error %v", err) + t.Errorf("Acquire_StaleToken;\nSemaphore.Acquire() got error %v", err) } if !got { - t.Error("Acquire_StaleLock;\nSemaphore.Acquire() got false want true") + t.Error("Acquire_StaleToken;\nSemaphore.Acquire() got false want true") } } func TestNewSemaphore_Release(t *testing.T) { - testID := uuid.New() - tests := []struct { desc string name string taskIDs []uuid.UUID ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc) wantCount int64 - wantErr string }{ { - desc: "Should decrease lock count", + desc: "Should decrease token count", name: "task-5", taskIDs: []uuid.UUID{uuid.New()}, ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) { @@ -219,7 +254,7 @@ func TestNewSemaphore_Release(t *testing.T) { }, }, { - desc: "Should decrease lock count by 2", + desc: "Should decrease token count by 2", name: "task-6", taskIDs: []uuid.UUID{uuid.New(), uuid.New()}, ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) { @@ -229,29 +264,6 @@ func TestNewSemaphore_Release(t *testing.T) { }, time.Now().Add(time.Second)) }, }, - { - desc: "Should return error when context is missing taskID", - name: "task-7", - taskIDs: []uuid.UUID{uuid.New()}, - ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) { - return context.WithTimeout(context.Background(), time.Second) - }, - wantCount: 1, - wantErr: "provided context is missing task ID value", - }, - { - desc: "Should return error when context has taskID which never acquired lock", - name: "task-8", - taskIDs: []uuid.UUID{uuid.New()}, - ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) { - return asynqcontext.New(&base.TaskMessage{ - ID: testID, - Queue: "task-4", - }, time.Now().Add(time.Second)) - }, - wantCount: 1, - wantErr: fmt.Sprintf("no token found for task %q", testID.String()), - }, } for _, tt := range tests { @@ -281,16 +293,10 @@ func TestNewSemaphore_Release(t *testing.T) { for i := 0; i < len(tt.taskIDs); i++ { ctx, cancel := tt.ctxFunc(tt.taskIDs[i]) - err := sema.Release(ctx) - - if tt.wantErr == "" && err != nil { + if err := sema.Release(ctx); err != nil { t.Errorf("%s;\nSemaphore.Release() got error %v", tt.desc, err) } - if tt.wantErr != "" && (err == nil || err.Error() != tt.wantErr) { - t.Errorf("%s;\nSemaphore.Release() got error %v want error %v", tt.desc, err, tt.wantErr) - } - cancel() } @@ -300,7 +306,77 @@ func TestNewSemaphore_Release(t *testing.T) { } if i != tt.wantCount { - t.Errorf("%s;\nSemaphore.Release(ctx) didn't release lock, got %v want 0", tt.desc, i) + t.Errorf("%s;\nSemaphore.Release(ctx) didn't release token, got %v want 0", tt.desc, i) + } + }) + } +} + +func TestNewSemaphore_Release_Error(t *testing.T) { + testID := uuid.New() + + tests := []struct { + desc string + name string + taskIDs []uuid.UUID + ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc) + errStr string + }{ + { + desc: "Should return error when context is missing taskID", + name: "task-7", + taskIDs: []uuid.UUID{uuid.New()}, + ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Second) + }, + errStr: "provided context is missing task ID value", + }, + { + desc: "Should return error when context has taskID which never acquired token", + name: "task-8", + taskIDs: []uuid.UUID{uuid.New()}, + ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) { + return asynqcontext.New(&base.TaskMessage{ + ID: testID, + Queue: "task-4", + }, time.Now().Add(time.Second)) + }, + errStr: fmt.Sprintf("no token found for task %q", testID.String()), + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + opt := getRedisConnOpt(t) + rc := opt.MakeRedisClient().(redis.UniversalClient) + defer rc.Close() + + if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil { + t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err) + } + + var members []*redis.Z + for i := 0; i < len(tt.taskIDs); i++ { + members = append(members, &redis.Z{ + Score: float64(time.Now().Add(time.Duration(i) * time.Second).Unix()), + Member: tt.taskIDs[i].String(), + }) + } + if err := rc.ZAdd(context.Background(), semaphoreKey(tt.name), members...).Err(); err != nil { + t.Errorf("%s;\nredis.UniversalClient.ZAdd() got error %v", tt.desc, err) + } + + sema := NewSemaphore(opt, tt.name, 3) + defer sema.Close() + + for i := 0; i < len(tt.taskIDs); i++ { + ctx, cancel := tt.ctxFunc(tt.taskIDs[i]) + + if err := sema.Release(ctx); err == nil || err.Error() != tt.errStr { + t.Errorf("%s;\nSemaphore.Release() got error %v want error %v", tt.desc, err, tt.errStr) + } + + cancel() } }) }