refactor: simplify context value retrieval using generics

wanghuaiyu@qiniu.com created

- Introduce generic getContextValue helper function to eliminate code duplication
- Reduce code from 75 to 56 lines (25.3% reduction)
- Simplify Get*FromContext functions from ~13 lines to 2 lines each
- Add comprehensive test coverage (18 test cases) for context functions
- Maintain backward compatibility with existing API

Change summary

internal/agent/tools/context_test.go | 219 ++++++++++++++++++++++++++++++
internal/agent/tools/tools.go        |  52 ++----
2 files changed, 236 insertions(+), 35 deletions(-)

Detailed changes

internal/agent/tools/context_test.go 🔗

@@ -0,0 +1,219 @@
+package tools
+
+import (
+	"context"
+	"testing"
+)
+
+func TestGetContextValue(t *testing.T) {
+	tests := []struct {
+		name         string
+		setup        func(ctx context.Context) context.Context
+		key          any
+		defaultValue any
+		want         any
+	}{
+		{
+			name: "returns string value",
+			setup: func(ctx context.Context) context.Context {
+				return context.WithValue(ctx, "testKey", "testValue")
+			},
+			key:          "testKey",
+			defaultValue: "",
+			want:         "testValue",
+		},
+		{
+			name: "returns default when key not found",
+			setup: func(ctx context.Context) context.Context {
+				return ctx
+			},
+			key:          "missingKey",
+			defaultValue: "default",
+			want:         "default",
+		},
+		{
+			name: "returns default when type mismatch",
+			setup: func(ctx context.Context) context.Context {
+				return context.WithValue(ctx, "testKey", 123) // int, not string
+			},
+			key:          "testKey",
+			defaultValue: "default",
+			want:         "default",
+		},
+		{
+			name: "returns bool value",
+			setup: func(ctx context.Context) context.Context {
+				return context.WithValue(ctx, "boolKey", true)
+			},
+			key:          "boolKey",
+			defaultValue: false,
+			want:         true,
+		},
+		{
+			name: "returns int value",
+			setup: func(ctx context.Context) context.Context {
+				return context.WithValue(ctx, "intKey", 42)
+			},
+			key:          "intKey",
+			defaultValue: 0,
+			want:         42,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ctx := tt.setup(context.Background())
+
+			var got any
+			switch tt.defaultValue.(type) {
+			case string:
+				got = getContextValue(ctx, tt.key, tt.defaultValue.(string))
+			case bool:
+				got = getContextValue(ctx, tt.key, tt.defaultValue.(bool))
+			case int:
+				got = getContextValue(ctx, tt.key, tt.defaultValue.(int))
+			}
+
+			if got != tt.want {
+				t.Errorf("getContextValue() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestGetSessionFromContext(t *testing.T) {
+	tests := []struct {
+		name string
+		ctx  context.Context
+		want string
+	}{
+		{
+			name: "returns session ID when present",
+			ctx:  context.WithValue(context.Background(), SessionIDContextKey, "session-123"),
+			want: "session-123",
+		},
+		{
+			name: "returns empty string when not present",
+			ctx:  context.Background(),
+			want: "",
+		},
+		{
+			name: "returns empty string when wrong type",
+			ctx:  context.WithValue(context.Background(), SessionIDContextKey, 123),
+			want: "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := GetSessionFromContext(tt.ctx)
+			if got != tt.want {
+				t.Errorf("GetSessionFromContext() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestGetMessageFromContext(t *testing.T) {
+	tests := []struct {
+		name string
+		ctx  context.Context
+		want string
+	}{
+		{
+			name: "returns message ID when present",
+			ctx:  context.WithValue(context.Background(), MessageIDContextKey, "msg-456"),
+			want: "msg-456",
+		},
+		{
+			name: "returns empty string when not present",
+			ctx:  context.Background(),
+			want: "",
+		},
+		{
+			name: "returns empty string when wrong type",
+			ctx:  context.WithValue(context.Background(), MessageIDContextKey, 456),
+			want: "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := GetMessageFromContext(tt.ctx)
+			if got != tt.want {
+				t.Errorf("GetMessageFromContext() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestGetSupportsImagesFromContext(t *testing.T) {
+	tests := []struct {
+		name string
+		ctx  context.Context
+		want bool
+	}{
+		{
+			name: "returns true when present and true",
+			ctx:  context.WithValue(context.Background(), SupportsImagesContextKey, true),
+			want: true,
+		},
+		{
+			name: "returns false when present and false",
+			ctx:  context.WithValue(context.Background(), SupportsImagesContextKey, false),
+			want: false,
+		},
+		{
+			name: "returns false when not present",
+			ctx:  context.Background(),
+			want: false,
+		},
+		{
+			name: "returns false when wrong type",
+			ctx:  context.WithValue(context.Background(), SupportsImagesContextKey, "true"),
+			want: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := GetSupportsImagesFromContext(tt.ctx)
+			if got != tt.want {
+				t.Errorf("GetSupportsImagesFromContext() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestGetModelNameFromContext(t *testing.T) {
+	tests := []struct {
+		name string
+		ctx  context.Context
+		want string
+	}{
+		{
+			name: "returns model name when present",
+			ctx:  context.WithValue(context.Background(), ModelNameContextKey, "claude-opus-4"),
+			want: "claude-opus-4",
+		},
+		{
+			name: "returns empty string when not present",
+			ctx:  context.Background(),
+			want: "",
+		},
+		{
+			name: "returns empty string when wrong type",
+			ctx:  context.WithValue(context.Background(), ModelNameContextKey, 789),
+			want: "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := GetModelNameFromContext(tt.ctx)
+			if got != tt.want {
+				t.Errorf("GetModelNameFromContext() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}

internal/agent/tools/tools.go 🔗

@@ -22,53 +22,35 @@ const (
 	ModelNameContextKey modelNameKey = "model_name"
 )
 
-// GetSessionFromContext retrieves the session ID from the context.
-func GetSessionFromContext(ctx context.Context) string {
-	sessionID := ctx.Value(SessionIDContextKey)
-	if sessionID == nil {
-		return ""
+// getContextValue is a generic helper that retrieves a typed value from context.
+// If the value is not found or has the wrong type, it returns the default value.
+func getContextValue[T any](ctx context.Context, key any, defaultValue T) T {
+	value := ctx.Value(key)
+	if value == nil {
+		return defaultValue
 	}
-	s, ok := sessionID.(string)
-	if !ok {
-		return ""
+	if typedValue, ok := value.(T); ok {
+		return typedValue
 	}
-	return s
+	return defaultValue
+}
+
+// GetSessionFromContext retrieves the session ID from the context.
+func GetSessionFromContext(ctx context.Context) string {
+	return getContextValue(ctx, SessionIDContextKey, "")
 }
 
 // GetMessageFromContext retrieves the message ID from the context.
 func GetMessageFromContext(ctx context.Context) string {
-	messageID := ctx.Value(MessageIDContextKey)
-	if messageID == nil {
-		return ""
-	}
-	s, ok := messageID.(string)
-	if !ok {
-		return ""
-	}
-	return s
+	return getContextValue(ctx, MessageIDContextKey, "")
 }
 
 // GetSupportsImagesFromContext retrieves whether the model supports images from the context.
 func GetSupportsImagesFromContext(ctx context.Context) bool {
-	supportsImages := ctx.Value(SupportsImagesContextKey)
-	if supportsImages == nil {
-		return false
-	}
-	if supports, ok := supportsImages.(bool); ok {
-		return supports
-	}
-	return false
+	return getContextValue(ctx, SupportsImagesContextKey, false)
 }
 
 // GetModelNameFromContext retrieves the model name from the context.
 func GetModelNameFromContext(ctx context.Context) string {
-	modelName := ctx.Value(ModelNameContextKey)
-	if modelName == nil {
-		return ""
-	}
-	s, ok := modelName.(string)
-	if !ok {
-		return ""
-	}
-	return s
+	return getContextValue(ctx, ModelNameContextKey, "")
 }