1package ant
2
3import (
4 "encoding/json"
5 "testing"
6
7 "shelley.exe.dev/llm"
8)
9
10// TestTextContentNoExtraFields verifies that text content doesn't include fields from other content types
11func TestTextContentNoExtraFields(t *testing.T) {
12 tests := []struct {
13 name string
14 content llm.Content
15 allowedFields map[string]bool
16 }{
17 {
18 name: "text content",
19 content: llm.Content{
20 Type: llm.ContentTypeText,
21 Text: "Hello world",
22 },
23 allowedFields: map[string]bool{
24 "type": true,
25 "text": true,
26 },
27 },
28 {
29 name: "tool_use content",
30 content: llm.Content{
31 Type: llm.ContentTypeToolUse,
32 ID: "toolu_123",
33 ToolName: "bash",
34 ToolInput: json.RawMessage(`{"command":"ls"}`),
35 },
36 allowedFields: map[string]bool{
37 "type": true,
38 "id": true,
39 "name": true,
40 "input": true,
41 },
42 },
43 {
44 name: "tool_result content",
45 content: llm.Content{
46 Type: llm.ContentTypeToolResult,
47 ToolUseID: "toolu_123",
48 ToolResult: []llm.Content{
49 {Type: llm.ContentTypeText, Text: "result"},
50 },
51 },
52 allowedFields: map[string]bool{
53 "type": true,
54 "tool_use_id": true,
55 "content": true,
56 },
57 },
58 }
59
60 for _, tt := range tests {
61 t.Run(tt.name, func(t *testing.T) {
62 antContent := fromLLMContent(tt.content)
63 jsonBytes, err := json.Marshal(antContent)
64 if err != nil {
65 t.Fatalf("failed to marshal content: %v", err)
66 }
67
68 var result map[string]interface{}
69 if err := json.Unmarshal(jsonBytes, &result); err != nil {
70 t.Fatalf("failed to unmarshal JSON: %v", err)
71 }
72
73 // Check that only allowed fields are present
74 for field := range result {
75 if !tt.allowedFields[field] {
76 t.Errorf("unexpected field %q in %s content: %s", field, tt.name, string(jsonBytes))
77 }
78 }
79
80 // Check that all required fields are present
81 for field := range tt.allowedFields {
82 if _, ok := result[field]; !ok && field != "cache_control" {
83 // cache_control is optional, so we don't require it
84 if field != "content" || tt.content.Type == llm.ContentTypeToolResult {
85 // Only check for content field if it's a tool_result
86 if field == "content" && tt.content.Type == llm.ContentTypeToolResult {
87 t.Errorf("missing required field %q in %s content: %s", field, tt.name, string(jsonBytes))
88 }
89 }
90 }
91 }
92 })
93 }
94}