From 07d065d780610a4a37e516733da951888293bf4b Mon Sep 17 00:00:00 2001 From: "wanghuaiyu@qiniu.com" Date: Sun, 15 Feb 2026 22:43:35 +0800 Subject: [PATCH] refactor: simplify context value retrieval using generics - Introduce generic getContextValue helper function to eliminate code duplication - Reduce code from 75 to 56 lines (25.3% reduction) - Simplify Get*FromContext functions from ~13 lines to 2 lines each - Add comprehensive test coverage (18 test cases) for context functions - Maintain backward compatibility with existing API --- internal/agent/tools/context_test.go | 219 +++++++++++++++++++++++++++ internal/agent/tools/tools.go | 52 +++---- 2 files changed, 236 insertions(+), 35 deletions(-) create mode 100644 internal/agent/tools/context_test.go diff --git a/internal/agent/tools/context_test.go b/internal/agent/tools/context_test.go new file mode 100644 index 0000000000000000000000000000000000000000..67a106bf23f9721fb8cb025dd2a6a7d9f349b188 --- /dev/null +++ b/internal/agent/tools/context_test.go @@ -0,0 +1,219 @@ +package tools + +import ( + "context" + "testing" +) + +func TestGetContextValue(t *testing.T) { + tests := []struct { + name string + setup func(ctx context.Context) context.Context + key any + defaultValue any + want any + }{ + { + name: "returns string value", + setup: func(ctx context.Context) context.Context { + return context.WithValue(ctx, "testKey", "testValue") + }, + key: "testKey", + defaultValue: "", + want: "testValue", + }, + { + name: "returns default when key not found", + setup: func(ctx context.Context) context.Context { + return ctx + }, + key: "missingKey", + defaultValue: "default", + want: "default", + }, + { + name: "returns default when type mismatch", + setup: func(ctx context.Context) context.Context { + return context.WithValue(ctx, "testKey", 123) // int, not string + }, + key: "testKey", + defaultValue: "default", + want: "default", + }, + { + name: "returns bool value", + setup: func(ctx context.Context) context.Context { + return context.WithValue(ctx, "boolKey", true) + }, + key: "boolKey", + defaultValue: false, + want: true, + }, + { + name: "returns int value", + setup: func(ctx context.Context) context.Context { + return context.WithValue(ctx, "intKey", 42) + }, + key: "intKey", + defaultValue: 0, + want: 42, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.setup(context.Background()) + + var got any + switch tt.defaultValue.(type) { + case string: + got = getContextValue(ctx, tt.key, tt.defaultValue.(string)) + case bool: + got = getContextValue(ctx, tt.key, tt.defaultValue.(bool)) + case int: + got = getContextValue(ctx, tt.key, tt.defaultValue.(int)) + } + + if got != tt.want { + t.Errorf("getContextValue() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetSessionFromContext(t *testing.T) { + tests := []struct { + name string + ctx context.Context + want string + }{ + { + name: "returns session ID when present", + ctx: context.WithValue(context.Background(), SessionIDContextKey, "session-123"), + want: "session-123", + }, + { + name: "returns empty string when not present", + ctx: context.Background(), + want: "", + }, + { + name: "returns empty string when wrong type", + ctx: context.WithValue(context.Background(), SessionIDContextKey, 123), + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetSessionFromContext(tt.ctx) + if got != tt.want { + t.Errorf("GetSessionFromContext() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetMessageFromContext(t *testing.T) { + tests := []struct { + name string + ctx context.Context + want string + }{ + { + name: "returns message ID when present", + ctx: context.WithValue(context.Background(), MessageIDContextKey, "msg-456"), + want: "msg-456", + }, + { + name: "returns empty string when not present", + ctx: context.Background(), + want: "", + }, + { + name: "returns empty string when wrong type", + ctx: context.WithValue(context.Background(), MessageIDContextKey, 456), + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetMessageFromContext(tt.ctx) + if got != tt.want { + t.Errorf("GetMessageFromContext() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetSupportsImagesFromContext(t *testing.T) { + tests := []struct { + name string + ctx context.Context + want bool + }{ + { + name: "returns true when present and true", + ctx: context.WithValue(context.Background(), SupportsImagesContextKey, true), + want: true, + }, + { + name: "returns false when present and false", + ctx: context.WithValue(context.Background(), SupportsImagesContextKey, false), + want: false, + }, + { + name: "returns false when not present", + ctx: context.Background(), + want: false, + }, + { + name: "returns false when wrong type", + ctx: context.WithValue(context.Background(), SupportsImagesContextKey, "true"), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetSupportsImagesFromContext(tt.ctx) + if got != tt.want { + t.Errorf("GetSupportsImagesFromContext() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetModelNameFromContext(t *testing.T) { + tests := []struct { + name string + ctx context.Context + want string + }{ + { + name: "returns model name when present", + ctx: context.WithValue(context.Background(), ModelNameContextKey, "claude-opus-4"), + want: "claude-opus-4", + }, + { + name: "returns empty string when not present", + ctx: context.Background(), + want: "", + }, + { + name: "returns empty string when wrong type", + ctx: context.WithValue(context.Background(), ModelNameContextKey, 789), + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetModelNameFromContext(tt.ctx) + if got != tt.want { + t.Errorf("GetModelNameFromContext() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/agent/tools/tools.go b/internal/agent/tools/tools.go index 7d03d0e22714205f0883de5b15960a104f9f6b98..50a2f7af24f9b1bc920fb88bc9a0df1123db9ebc 100644 --- a/internal/agent/tools/tools.go +++ b/internal/agent/tools/tools.go @@ -22,53 +22,35 @@ const ( ModelNameContextKey modelNameKey = "model_name" ) -// GetSessionFromContext retrieves the session ID from the context. -func GetSessionFromContext(ctx context.Context) string { - sessionID := ctx.Value(SessionIDContextKey) - if sessionID == nil { - return "" +// getContextValue is a generic helper that retrieves a typed value from context. +// If the value is not found or has the wrong type, it returns the default value. +func getContextValue[T any](ctx context.Context, key any, defaultValue T) T { + value := ctx.Value(key) + if value == nil { + return defaultValue } - s, ok := sessionID.(string) - if !ok { - return "" + if typedValue, ok := value.(T); ok { + return typedValue } - return s + return defaultValue +} + +// GetSessionFromContext retrieves the session ID from the context. +func GetSessionFromContext(ctx context.Context) string { + return getContextValue(ctx, SessionIDContextKey, "") } // GetMessageFromContext retrieves the message ID from the context. func GetMessageFromContext(ctx context.Context) string { - messageID := ctx.Value(MessageIDContextKey) - if messageID == nil { - return "" - } - s, ok := messageID.(string) - if !ok { - return "" - } - return s + return getContextValue(ctx, MessageIDContextKey, "") } // GetSupportsImagesFromContext retrieves whether the model supports images from the context. func GetSupportsImagesFromContext(ctx context.Context) bool { - supportsImages := ctx.Value(SupportsImagesContextKey) - if supportsImages == nil { - return false - } - if supports, ok := supportsImages.(bool); ok { - return supports - } - return false + return getContextValue(ctx, SupportsImagesContextKey, false) } // GetModelNameFromContext retrieves the model name from the context. func GetModelNameFromContext(ctx context.Context) string { - modelName := ctx.Value(ModelNameContextKey) - if modelName == nil { - return "" - } - s, ok := modelName.(string) - if !ok { - return "" - } - return s + return getContextValue(ctx, ModelNameContextKey, "") }