1package fantasy
2
3import (
4 "encoding/json"
5 "errors"
6 "reflect"
7 "testing"
8)
9
10func TestMessageJSONSerialization(t *testing.T) {
11 tests := []struct {
12 name string
13 message Message
14 }{
15 {
16 name: "simple text message",
17 message: Message{
18 Role: MessageRoleUser,
19 Content: []MessagePart{
20 TextPart{Text: "Hello, world!"},
21 },
22 },
23 },
24 {
25 name: "message with multiple text parts",
26 message: Message{
27 Role: MessageRoleAssistant,
28 Content: []MessagePart{
29 TextPart{Text: "First part"},
30 TextPart{Text: "Second part"},
31 TextPart{Text: "Third part"},
32 },
33 },
34 },
35 {
36 name: "message with reasoning part",
37 message: Message{
38 Role: MessageRoleAssistant,
39 Content: []MessagePart{
40 ReasoningPart{Text: "Let me think about this..."},
41 TextPart{Text: "Here's my answer"},
42 },
43 },
44 },
45 {
46 name: "message with file part",
47 message: Message{
48 Role: MessageRoleUser,
49 Content: []MessagePart{
50 TextPart{Text: "Here's an image:"},
51 FilePart{
52 Filename: "test.png",
53 Data: []byte{0x89, 0x50, 0x4E, 0x47}, // PNG header
54 MediaType: "image/png",
55 },
56 },
57 },
58 },
59 {
60 name: "message with tool call",
61 message: Message{
62 Role: MessageRoleAssistant,
63 Content: []MessagePart{
64 ToolCallPart{
65 ToolCallID: "call_123",
66 ToolName: "get_weather",
67 Input: `{"location": "San Francisco"}`,
68 ProviderExecuted: false,
69 },
70 },
71 },
72 },
73 {
74 name: "message with tool result - text output",
75 message: Message{
76 Role: MessageRoleTool,
77 Content: []MessagePart{
78 ToolResultPart{
79 ToolCallID: "call_123",
80 Output: ToolResultOutputContentText{
81 Text: "The weather is sunny, 72ยฐF",
82 },
83 },
84 },
85 },
86 },
87 {
88 name: "message with tool result - error output",
89 message: Message{
90 Role: MessageRoleTool,
91 Content: []MessagePart{
92 ToolResultPart{
93 ToolCallID: "call_456",
94 Output: ToolResultOutputContentError{
95 Error: errors.New("API rate limit exceeded"),
96 },
97 },
98 },
99 },
100 },
101 {
102 name: "message with tool result - media output",
103 message: Message{
104 Role: MessageRoleTool,
105 Content: []MessagePart{
106 ToolResultPart{
107 ToolCallID: "call_789",
108 Output: ToolResultOutputContentMedia{
109 Data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
110 MediaType: "image/png",
111 },
112 },
113 },
114 },
115 },
116 {
117 name: "complex message with mixed content",
118 message: Message{
119 Role: MessageRoleAssistant,
120 Content: []MessagePart{
121 TextPart{Text: "I'll analyze this image and call some tools."},
122 ReasoningPart{Text: "First, I need to identify the objects..."},
123 ToolCallPart{
124 ToolCallID: "call_001",
125 ToolName: "analyze_image",
126 Input: `{"image_id": "img_123"}`,
127 ProviderExecuted: false,
128 },
129 ToolCallPart{
130 ToolCallID: "call_002",
131 ToolName: "get_context",
132 Input: `{"query": "similar images"}`,
133 ProviderExecuted: true,
134 },
135 },
136 },
137 },
138 {
139 name: "system message",
140 message: Message{
141 Role: MessageRoleSystem,
142 Content: []MessagePart{
143 TextPart{Text: "You are a helpful assistant."},
144 },
145 },
146 },
147 {
148 name: "empty content",
149 message: Message{
150 Role: MessageRoleUser,
151 Content: []MessagePart{},
152 },
153 },
154 }
155
156 for _, tt := range tests {
157 t.Run(tt.name, func(t *testing.T) {
158 // Marshal the message
159 data, err := json.Marshal(tt.message)
160 if err != nil {
161 t.Fatalf("failed to marshal message: %v", err)
162 }
163
164 // Unmarshal back
165 var decoded Message
166 err = json.Unmarshal(data, &decoded)
167 if err != nil {
168 t.Fatalf("failed to unmarshal message: %v", err)
169 }
170
171 // Compare roles
172 if decoded.Role != tt.message.Role {
173 t.Errorf("role mismatch: got %v, want %v", decoded.Role, tt.message.Role)
174 }
175
176 // Compare content length
177 if len(decoded.Content) != len(tt.message.Content) {
178 t.Fatalf("content length mismatch: got %d, want %d", len(decoded.Content), len(tt.message.Content))
179 }
180
181 // Compare each content part
182 for i := range tt.message.Content {
183 original := tt.message.Content[i]
184 decodedPart := decoded.Content[i]
185
186 if original.GetType() != decodedPart.GetType() {
187 t.Errorf("content[%d] type mismatch: got %v, want %v", i, decodedPart.GetType(), original.GetType())
188 continue
189 }
190
191 compareMessagePart(t, i, original, decodedPart)
192 }
193 })
194 }
195}
196
197func compareMessagePart(t *testing.T, index int, original, decoded MessagePart) {
198 switch original.GetType() {
199 case ContentTypeText:
200 orig := original.(TextPart)
201 dec := decoded.(TextPart)
202 if orig.Text != dec.Text {
203 t.Errorf("content[%d] text mismatch: got %q, want %q", index, dec.Text, orig.Text)
204 }
205
206 case ContentTypeReasoning:
207 orig := original.(ReasoningPart)
208 dec := decoded.(ReasoningPart)
209 if orig.Text != dec.Text {
210 t.Errorf("content[%d] reasoning text mismatch: got %q, want %q", index, dec.Text, orig.Text)
211 }
212
213 case ContentTypeFile:
214 orig := original.(FilePart)
215 dec := decoded.(FilePart)
216 if orig.Filename != dec.Filename {
217 t.Errorf("content[%d] filename mismatch: got %q, want %q", index, dec.Filename, orig.Filename)
218 }
219 if orig.MediaType != dec.MediaType {
220 t.Errorf("content[%d] media type mismatch: got %q, want %q", index, dec.MediaType, orig.MediaType)
221 }
222 if !reflect.DeepEqual(orig.Data, dec.Data) {
223 t.Errorf("content[%d] file data mismatch", index)
224 }
225
226 case ContentTypeToolCall:
227 orig := original.(ToolCallPart)
228 dec := decoded.(ToolCallPart)
229 if orig.ToolCallID != dec.ToolCallID {
230 t.Errorf("content[%d] tool call id mismatch: got %q, want %q", index, dec.ToolCallID, orig.ToolCallID)
231 }
232 if orig.ToolName != dec.ToolName {
233 t.Errorf("content[%d] tool name mismatch: got %q, want %q", index, dec.ToolName, orig.ToolName)
234 }
235 if orig.Input != dec.Input {
236 t.Errorf("content[%d] tool input mismatch: got %q, want %q", index, dec.Input, orig.Input)
237 }
238 if orig.ProviderExecuted != dec.ProviderExecuted {
239 t.Errorf("content[%d] provider executed mismatch: got %v, want %v", index, dec.ProviderExecuted, orig.ProviderExecuted)
240 }
241
242 case ContentTypeToolResult:
243 orig := original.(ToolResultPart)
244 dec := decoded.(ToolResultPart)
245 if orig.ToolCallID != dec.ToolCallID {
246 t.Errorf("content[%d] tool result call id mismatch: got %q, want %q", index, dec.ToolCallID, orig.ToolCallID)
247 }
248 compareToolResultOutput(t, index, orig.Output, dec.Output)
249 }
250}
251
252func compareToolResultOutput(t *testing.T, index int, original, decoded ToolResultOutputContent) {
253 if original.GetType() != decoded.GetType() {
254 t.Errorf("content[%d] tool result output type mismatch: got %v, want %v", index, decoded.GetType(), original.GetType())
255 return
256 }
257
258 switch original.GetType() {
259 case ToolResultContentTypeText:
260 orig := original.(ToolResultOutputContentText)
261 dec := decoded.(ToolResultOutputContentText)
262 if orig.Text != dec.Text {
263 t.Errorf("content[%d] tool result text mismatch: got %q, want %q", index, dec.Text, orig.Text)
264 }
265
266 case ToolResultContentTypeError:
267 orig := original.(ToolResultOutputContentError)
268 dec := decoded.(ToolResultOutputContentError)
269 if orig.Error.Error() != dec.Error.Error() {
270 t.Errorf("content[%d] tool result error mismatch: got %q, want %q", index, dec.Error.Error(), orig.Error.Error())
271 }
272
273 case ToolResultContentTypeMedia:
274 orig := original.(ToolResultOutputContentMedia)
275 dec := decoded.(ToolResultOutputContentMedia)
276 if orig.Data != dec.Data {
277 t.Errorf("content[%d] tool result media data mismatch", index)
278 }
279 if orig.MediaType != dec.MediaType {
280 t.Errorf("content[%d] tool result media type mismatch: got %q, want %q", index, dec.MediaType, orig.MediaType)
281 }
282 }
283}
284
285func TestHelperFunctions(t *testing.T) {
286 t.Run("NewUserMessage - text only", func(t *testing.T) {
287 msg := NewUserMessage("Hello")
288
289 data, err := json.Marshal(msg)
290 if err != nil {
291 t.Fatalf("failed to marshal: %v", err)
292 }
293
294 var decoded Message
295 if err := json.Unmarshal(data, &decoded); err != nil {
296 t.Fatalf("failed to unmarshal: %v", err)
297 }
298
299 if decoded.Role != MessageRoleUser {
300 t.Errorf("role mismatch: got %v, want %v", decoded.Role, MessageRoleUser)
301 }
302
303 if len(decoded.Content) != 1 {
304 t.Fatalf("expected 1 content part, got %d", len(decoded.Content))
305 }
306
307 textPart := decoded.Content[0].(TextPart)
308 if textPart.Text != "Hello" {
309 t.Errorf("text mismatch: got %q, want %q", textPart.Text, "Hello")
310 }
311 })
312
313 t.Run("NewUserMessage - with files", func(t *testing.T) {
314 msg := NewUserMessage("Check this image",
315 FilePart{
316 Filename: "image1.jpg",
317 Data: []byte{0xFF, 0xD8, 0xFF},
318 MediaType: "image/jpeg",
319 },
320 FilePart{
321 Filename: "image2.png",
322 Data: []byte{0x89, 0x50, 0x4E, 0x47},
323 MediaType: "image/png",
324 },
325 )
326
327 data, err := json.Marshal(msg)
328 if err != nil {
329 t.Fatalf("failed to marshal: %v", err)
330 }
331
332 var decoded Message
333 if err := json.Unmarshal(data, &decoded); err != nil {
334 t.Fatalf("failed to unmarshal: %v", err)
335 }
336
337 if len(decoded.Content) != 3 {
338 t.Fatalf("expected 3 content parts, got %d", len(decoded.Content))
339 }
340
341 // Check text part
342 textPart := decoded.Content[0].(TextPart)
343 if textPart.Text != "Check this image" {
344 t.Errorf("text mismatch: got %q, want %q", textPart.Text, "Check this image")
345 }
346
347 // Check first file
348 file1 := decoded.Content[1].(FilePart)
349 if file1.Filename != "image1.jpg" {
350 t.Errorf("file1 name mismatch: got %q, want %q", file1.Filename, "image1.jpg")
351 }
352
353 // Check second file
354 file2 := decoded.Content[2].(FilePart)
355 if file2.Filename != "image2.png" {
356 t.Errorf("file2 name mismatch: got %q, want %q", file2.Filename, "image2.png")
357 }
358 })
359
360 t.Run("NewSystemMessage - single prompt", func(t *testing.T) {
361 msg := NewSystemMessage("You are a helpful assistant.")
362
363 data, err := json.Marshal(msg)
364 if err != nil {
365 t.Fatalf("failed to marshal: %v", err)
366 }
367
368 var decoded Message
369 if err := json.Unmarshal(data, &decoded); err != nil {
370 t.Fatalf("failed to unmarshal: %v", err)
371 }
372
373 if decoded.Role != MessageRoleSystem {
374 t.Errorf("role mismatch: got %v, want %v", decoded.Role, MessageRoleSystem)
375 }
376
377 if len(decoded.Content) != 1 {
378 t.Fatalf("expected 1 content part, got %d", len(decoded.Content))
379 }
380
381 textPart := decoded.Content[0].(TextPart)
382 if textPart.Text != "You are a helpful assistant." {
383 t.Errorf("text mismatch: got %q, want %q", textPart.Text, "You are a helpful assistant.")
384 }
385 })
386
387 t.Run("NewSystemMessage - multiple prompts", func(t *testing.T) {
388 msg := NewSystemMessage("First instruction", "Second instruction", "Third instruction")
389
390 data, err := json.Marshal(msg)
391 if err != nil {
392 t.Fatalf("failed to marshal: %v", err)
393 }
394
395 var decoded Message
396 if err := json.Unmarshal(data, &decoded); err != nil {
397 t.Fatalf("failed to unmarshal: %v", err)
398 }
399
400 if len(decoded.Content) != 3 {
401 t.Fatalf("expected 3 content parts, got %d", len(decoded.Content))
402 }
403
404 expected := []string{"First instruction", "Second instruction", "Third instruction"}
405 for i, exp := range expected {
406 textPart := decoded.Content[i].(TextPart)
407 if textPart.Text != exp {
408 t.Errorf("content[%d] text mismatch: got %q, want %q", i, textPart.Text, exp)
409 }
410 }
411 })
412}
413
414func TestEdgeCases(t *testing.T) {
415 t.Run("empty text part", func(t *testing.T) {
416 msg := Message{
417 Role: MessageRoleUser,
418 Content: []MessagePart{
419 TextPart{Text: ""},
420 },
421 }
422
423 data, err := json.Marshal(msg)
424 if err != nil {
425 t.Fatalf("failed to marshal: %v", err)
426 }
427
428 var decoded Message
429 if err := json.Unmarshal(data, &decoded); err != nil {
430 t.Fatalf("failed to unmarshal: %v", err)
431 }
432
433 textPart := decoded.Content[0].(TextPart)
434 if textPart.Text != "" {
435 t.Errorf("expected empty text, got %q", textPart.Text)
436 }
437 })
438
439 t.Run("nil error in tool result", func(t *testing.T) {
440 msg := Message{
441 Role: MessageRoleTool,
442 Content: []MessagePart{
443 ToolResultPart{
444 ToolCallID: "call_123",
445 Output: ToolResultOutputContentError{
446 Error: nil,
447 },
448 },
449 },
450 }
451
452 data, err := json.Marshal(msg)
453 if err != nil {
454 t.Fatalf("failed to marshal: %v", err)
455 }
456
457 var decoded Message
458 if err := json.Unmarshal(data, &decoded); err != nil {
459 t.Fatalf("failed to unmarshal: %v", err)
460 }
461
462 toolResult := decoded.Content[0].(ToolResultPart)
463 errorOutput := toolResult.Output.(ToolResultOutputContentError)
464 if errorOutput.Error != nil {
465 t.Errorf("expected nil error, got %v", errorOutput.Error)
466 }
467 })
468
469 t.Run("empty file data", func(t *testing.T) {
470 msg := Message{
471 Role: MessageRoleUser,
472 Content: []MessagePart{
473 FilePart{
474 Filename: "empty.txt",
475 Data: []byte{},
476 MediaType: "text/plain",
477 },
478 },
479 }
480
481 data, err := json.Marshal(msg)
482 if err != nil {
483 t.Fatalf("failed to marshal: %v", err)
484 }
485
486 var decoded Message
487 if err := json.Unmarshal(data, &decoded); err != nil {
488 t.Fatalf("failed to unmarshal: %v", err)
489 }
490
491 filePart := decoded.Content[0].(FilePart)
492 if len(filePart.Data) != 0 {
493 t.Errorf("expected empty data, got %d bytes", len(filePart.Data))
494 }
495 })
496
497 t.Run("unicode in text", func(t *testing.T) {
498 msg := Message{
499 Role: MessageRoleUser,
500 Content: []MessagePart{
501 TextPart{Text: "Hello ไธ็! ๐ ะัะธะฒะตั"},
502 },
503 }
504
505 data, err := json.Marshal(msg)
506 if err != nil {
507 t.Fatalf("failed to marshal: %v", err)
508 }
509
510 var decoded Message
511 if err := json.Unmarshal(data, &decoded); err != nil {
512 t.Fatalf("failed to unmarshal: %v", err)
513 }
514
515 textPart := decoded.Content[0].(TextPart)
516 if textPart.Text != "Hello ไธ็! ๐ ะัะธะฒะตั" {
517 t.Errorf("unicode text mismatch: got %q, want %q", textPart.Text, "Hello ไธ็! ๐ ะัะธะฒะตั")
518 }
519 })
520}
521
522func TestInvalidJSONHandling(t *testing.T) {
523 t.Run("unknown message part type", func(t *testing.T) {
524 invalidJSON := `{
525 "role": "user",
526 "content": [
527 {
528 "type": "unknown-type",
529 "data": {}
530 }
531 ],
532 "provider_options": null
533 }`
534
535 var msg Message
536 err := json.Unmarshal([]byte(invalidJSON), &msg)
537 if err == nil {
538 t.Error("expected error for unknown message part type, got nil")
539 }
540 })
541
542 t.Run("unknown tool result output type", func(t *testing.T) {
543 invalidJSON := `{
544 "role": "tool",
545 "content": [
546 {
547 "type": "tool-result",
548 "data": {
549 "tool_call_id": "call_123",
550 "output": {
551 "type": "unknown-output-type",
552 "data": {}
553 },
554 "provider_options": null
555 }
556 }
557 ],
558 "provider_options": null
559 }`
560
561 var msg Message
562 err := json.Unmarshal([]byte(invalidJSON), &msg)
563 if err == nil {
564 t.Error("expected error for unknown tool result output type, got nil")
565 }
566 })
567
568 t.Run("malformed JSON", func(t *testing.T) {
569 invalidJSON := `{"role": "user", "content": [`
570
571 var msg Message
572 err := json.Unmarshal([]byte(invalidJSON), &msg)
573 if err == nil {
574 t.Error("expected error for malformed JSON, got nil")
575 }
576 })
577}
578
// mockProviderData is fake provider-specific data for tests that exercise
// provider options. Its JSON form is {"type":"mock","key":...}.
// NOTE(review): Options() and Type() presumably exist to satisfy the package's
// provider-options interface; confirm against that interface's definition.
type mockProviderData struct {
	Key string `json:"key"`
}

// Options is a no-op marker method.
func (m mockProviderData) Options() {}

// Type returns the discriminator "mock", the same value MarshalJSON writes
// into the "type" field.
func (m mockProviderData) Type() string { return "mock" }

// MarshalJSON encodes m with an added "type":"mock" discriminator field.
func (m mockProviderData) MarshalJSON() ([]byte, error) {
	// plain has mockProviderData's fields but none of its methods. Embedding
	// mockProviderData directly would promote this MarshalJSON onto the
	// wrapper struct, and json.Marshal would then recurse forever.
	type plain mockProviderData
	return json.Marshal(struct {
		Type string `json:"type"`
		plain
	}{
		Type:  "mock",
		plain: plain(m),
	})
}

// UnmarshalJSON decodes the form produced by MarshalJSON into m. The "type"
// field is accepted but not validated.
func (m *mockProviderData) UnmarshalJSON(data []byte) error {
	// See MarshalJSON: embedding the method-less plain type prevents the
	// promoted (*mockProviderData).UnmarshalJSON from being invoked recursively.
	type plain mockProviderData
	var aux struct {
		Type string `json:"type"`
		plain
	}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}
	*m = mockProviderData(aux.plain)
	return nil
}
607
608func TestPromptSerialization(t *testing.T) {
609 t.Run("serialize prompt (message slice)", func(t *testing.T) {
610 prompt := Prompt{
611 NewSystemMessage("You are helpful"),
612 NewUserMessage("Hello"),
613 Message{
614 Role: MessageRoleAssistant,
615 Content: []MessagePart{
616 TextPart{Text: "Hi there!"},
617 },
618 },
619 }
620
621 data, err := json.Marshal(prompt)
622 if err != nil {
623 t.Fatalf("failed to marshal prompt: %v", err)
624 }
625
626 var decoded Prompt
627 if err := json.Unmarshal(data, &decoded); err != nil {
628 t.Fatalf("failed to unmarshal prompt: %v", err)
629 }
630
631 if len(decoded) != 3 {
632 t.Fatalf("expected 3 messages, got %d", len(decoded))
633 }
634
635 if decoded[0].Role != MessageRoleSystem {
636 t.Errorf("message 0 role mismatch: got %v, want %v", decoded[0].Role, MessageRoleSystem)
637 }
638
639 if decoded[1].Role != MessageRoleUser {
640 t.Errorf("message 1 role mismatch: got %v, want %v", decoded[1].Role, MessageRoleUser)
641 }
642
643 if decoded[2].Role != MessageRoleAssistant {
644 t.Errorf("message 2 role mismatch: got %v, want %v", decoded[2].Role, MessageRoleAssistant)
645 }
646 })
647}