content_fields_test.go

 1package ant
 2
 3import (
 4	"encoding/json"
 5	"testing"
 6
 7	"shelley.exe.dev/llm"
 8)
 9
10// TestTextContentNoExtraFields verifies that text content doesn't include fields from other content types
11func TestTextContentNoExtraFields(t *testing.T) {
12	tests := []struct {
13		name          string
14		content       llm.Content
15		allowedFields map[string]bool
16	}{
17		{
18			name: "text content",
19			content: llm.Content{
20				Type: llm.ContentTypeText,
21				Text: "Hello world",
22			},
23			allowedFields: map[string]bool{
24				"type": true,
25				"text": true,
26			},
27		},
28		{
29			name: "tool_use content",
30			content: llm.Content{
31				Type:      llm.ContentTypeToolUse,
32				ID:        "toolu_123",
33				ToolName:  "bash",
34				ToolInput: json.RawMessage(`{"command":"ls"}`),
35			},
36			allowedFields: map[string]bool{
37				"type":  true,
38				"id":    true,
39				"name":  true,
40				"input": true,
41			},
42		},
43		{
44			name: "tool_result content",
45			content: llm.Content{
46				Type:      llm.ContentTypeToolResult,
47				ToolUseID: "toolu_123",
48				ToolResult: []llm.Content{
49					{Type: llm.ContentTypeText, Text: "result"},
50				},
51			},
52			allowedFields: map[string]bool{
53				"type":        true,
54				"tool_use_id": true,
55				"content":     true,
56			},
57		},
58	}
59
60	for _, tt := range tests {
61		t.Run(tt.name, func(t *testing.T) {
62			antContent := fromLLMContent(tt.content)
63			jsonBytes, err := json.Marshal(antContent)
64			if err != nil {
65				t.Fatalf("failed to marshal content: %v", err)
66			}
67
68			var result map[string]interface{}
69			if err := json.Unmarshal(jsonBytes, &result); err != nil {
70				t.Fatalf("failed to unmarshal JSON: %v", err)
71			}
72
73			// Check that only allowed fields are present
74			for field := range result {
75				if !tt.allowedFields[field] {
76					t.Errorf("unexpected field %q in %s content: %s", field, tt.name, string(jsonBytes))
77				}
78			}
79
80			// Check that all required fields are present
81			for field := range tt.allowedFields {
82				if _, ok := result[field]; !ok && field != "cache_control" {
83					// cache_control is optional, so we don't require it
84					if field != "content" || tt.content.Type == llm.ContentTypeToolResult {
85						// Only check for content field if it's a tool_result
86						if field == "content" && tt.content.Type == llm.ContentTypeToolResult {
87							t.Errorf("missing required field %q in %s content: %s", field, tt.name, string(jsonBytes))
88						}
89					}
90				}
91			}
92		})
93	}
94}