json_test.go

  1package fantasy
  2
  3import (
  4	"encoding/json"
  5	"errors"
  6	"reflect"
  7	"testing"
  8)
  9
 10func TestMessageJSONSerialization(t *testing.T) {
 11	tests := []struct {
 12		name    string
 13		message Message
 14	}{
 15		{
 16			name: "simple text message",
 17			message: Message{
 18				Role: MessageRoleUser,
 19				Content: []MessagePart{
 20					TextPart{Text: "Hello, world!"},
 21				},
 22			},
 23		},
 24		{
 25			name: "message with multiple text parts",
 26			message: Message{
 27				Role: MessageRoleAssistant,
 28				Content: []MessagePart{
 29					TextPart{Text: "First part"},
 30					TextPart{Text: "Second part"},
 31					TextPart{Text: "Third part"},
 32				},
 33			},
 34		},
 35		{
 36			name: "message with reasoning part",
 37			message: Message{
 38				Role: MessageRoleAssistant,
 39				Content: []MessagePart{
 40					ReasoningPart{Text: "Let me think about this..."},
 41					TextPart{Text: "Here's my answer"},
 42				},
 43			},
 44		},
 45		{
 46			name: "message with file part",
 47			message: Message{
 48				Role: MessageRoleUser,
 49				Content: []MessagePart{
 50					TextPart{Text: "Here's an image:"},
 51					FilePart{
 52						Filename:  "test.png",
 53						Data:      []byte{0x89, 0x50, 0x4E, 0x47}, // PNG header
 54						MediaType: "image/png",
 55					},
 56				},
 57			},
 58		},
 59		{
 60			name: "message with tool call",
 61			message: Message{
 62				Role: MessageRoleAssistant,
 63				Content: []MessagePart{
 64					ToolCallPart{
 65						ToolCallID:       "call_123",
 66						ToolName:         "get_weather",
 67						Input:            `{"location": "San Francisco"}`,
 68						ProviderExecuted: false,
 69					},
 70				},
 71			},
 72		},
 73		{
 74			name: "message with tool result - text output",
 75			message: Message{
 76				Role: MessageRoleTool,
 77				Content: []MessagePart{
 78					ToolResultPart{
 79						ToolCallID: "call_123",
 80						Output: ToolResultOutputContentText{
 81							Text: "The weather is sunny, 72ยฐF",
 82						},
 83					},
 84				},
 85			},
 86		},
 87		{
 88			name: "message with tool result - error output",
 89			message: Message{
 90				Role: MessageRoleTool,
 91				Content: []MessagePart{
 92					ToolResultPart{
 93						ToolCallID: "call_456",
 94						Output: ToolResultOutputContentError{
 95							Error: errors.New("API rate limit exceeded"),
 96						},
 97					},
 98				},
 99			},
100		},
101		{
102			name: "message with tool result - media output",
103			message: Message{
104				Role: MessageRoleTool,
105				Content: []MessagePart{
106					ToolResultPart{
107						ToolCallID: "call_789",
108						Output: ToolResultOutputContentMedia{
109							Data:      "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
110							MediaType: "image/png",
111						},
112					},
113				},
114			},
115		},
116		{
117			name: "complex message with mixed content",
118			message: Message{
119				Role: MessageRoleAssistant,
120				Content: []MessagePart{
121					TextPart{Text: "I'll analyze this image and call some tools."},
122					ReasoningPart{Text: "First, I need to identify the objects..."},
123					ToolCallPart{
124						ToolCallID:       "call_001",
125						ToolName:         "analyze_image",
126						Input:            `{"image_id": "img_123"}`,
127						ProviderExecuted: false,
128					},
129					ToolCallPart{
130						ToolCallID:       "call_002",
131						ToolName:         "get_context",
132						Input:            `{"query": "similar images"}`,
133						ProviderExecuted: true,
134					},
135				},
136			},
137		},
138		{
139			name: "system message",
140			message: Message{
141				Role: MessageRoleSystem,
142				Content: []MessagePart{
143					TextPart{Text: "You are a helpful assistant."},
144				},
145			},
146		},
147		{
148			name: "empty content",
149			message: Message{
150				Role:    MessageRoleUser,
151				Content: []MessagePart{},
152			},
153		},
154	}
155
156	for _, tt := range tests {
157		t.Run(tt.name, func(t *testing.T) {
158			// Marshal the message
159			data, err := json.Marshal(tt.message)
160			if err != nil {
161				t.Fatalf("failed to marshal message: %v", err)
162			}
163
164			// Unmarshal back
165			var decoded Message
166			err = json.Unmarshal(data, &decoded)
167			if err != nil {
168				t.Fatalf("failed to unmarshal message: %v", err)
169			}
170
171			// Compare roles
172			if decoded.Role != tt.message.Role {
173				t.Errorf("role mismatch: got %v, want %v", decoded.Role, tt.message.Role)
174			}
175
176			// Compare content length
177			if len(decoded.Content) != len(tt.message.Content) {
178				t.Fatalf("content length mismatch: got %d, want %d", len(decoded.Content), len(tt.message.Content))
179			}
180
181			// Compare each content part
182			for i := range tt.message.Content {
183				original := tt.message.Content[i]
184				decodedPart := decoded.Content[i]
185
186				if original.GetType() != decodedPart.GetType() {
187					t.Errorf("content[%d] type mismatch: got %v, want %v", i, decodedPart.GetType(), original.GetType())
188					continue
189				}
190
191				compareMessagePart(t, i, original, decodedPart)
192			}
193		})
194	}
195}
196
197func compareMessagePart(t *testing.T, index int, original, decoded MessagePart) {
198	switch original.GetType() {
199	case ContentTypeText:
200		orig := original.(TextPart)
201		dec := decoded.(TextPart)
202		if orig.Text != dec.Text {
203			t.Errorf("content[%d] text mismatch: got %q, want %q", index, dec.Text, orig.Text)
204		}
205
206	case ContentTypeReasoning:
207		orig := original.(ReasoningPart)
208		dec := decoded.(ReasoningPart)
209		if orig.Text != dec.Text {
210			t.Errorf("content[%d] reasoning text mismatch: got %q, want %q", index, dec.Text, orig.Text)
211		}
212
213	case ContentTypeFile:
214		orig := original.(FilePart)
215		dec := decoded.(FilePart)
216		if orig.Filename != dec.Filename {
217			t.Errorf("content[%d] filename mismatch: got %q, want %q", index, dec.Filename, orig.Filename)
218		}
219		if orig.MediaType != dec.MediaType {
220			t.Errorf("content[%d] media type mismatch: got %q, want %q", index, dec.MediaType, orig.MediaType)
221		}
222		if !reflect.DeepEqual(orig.Data, dec.Data) {
223			t.Errorf("content[%d] file data mismatch", index)
224		}
225
226	case ContentTypeToolCall:
227		orig := original.(ToolCallPart)
228		dec := decoded.(ToolCallPart)
229		if orig.ToolCallID != dec.ToolCallID {
230			t.Errorf("content[%d] tool call id mismatch: got %q, want %q", index, dec.ToolCallID, orig.ToolCallID)
231		}
232		if orig.ToolName != dec.ToolName {
233			t.Errorf("content[%d] tool name mismatch: got %q, want %q", index, dec.ToolName, orig.ToolName)
234		}
235		if orig.Input != dec.Input {
236			t.Errorf("content[%d] tool input mismatch: got %q, want %q", index, dec.Input, orig.Input)
237		}
238		if orig.ProviderExecuted != dec.ProviderExecuted {
239			t.Errorf("content[%d] provider executed mismatch: got %v, want %v", index, dec.ProviderExecuted, orig.ProviderExecuted)
240		}
241
242	case ContentTypeToolResult:
243		orig := original.(ToolResultPart)
244		dec := decoded.(ToolResultPart)
245		if orig.ToolCallID != dec.ToolCallID {
246			t.Errorf("content[%d] tool result call id mismatch: got %q, want %q", index, dec.ToolCallID, orig.ToolCallID)
247		}
248		compareToolResultOutput(t, index, orig.Output, dec.Output)
249	}
250}
251
252func compareToolResultOutput(t *testing.T, index int, original, decoded ToolResultOutputContent) {
253	if original.GetType() != decoded.GetType() {
254		t.Errorf("content[%d] tool result output type mismatch: got %v, want %v", index, decoded.GetType(), original.GetType())
255		return
256	}
257
258	switch original.GetType() {
259	case ToolResultContentTypeText:
260		orig := original.(ToolResultOutputContentText)
261		dec := decoded.(ToolResultOutputContentText)
262		if orig.Text != dec.Text {
263			t.Errorf("content[%d] tool result text mismatch: got %q, want %q", index, dec.Text, orig.Text)
264		}
265
266	case ToolResultContentTypeError:
267		orig := original.(ToolResultOutputContentError)
268		dec := decoded.(ToolResultOutputContentError)
269		if orig.Error.Error() != dec.Error.Error() {
270			t.Errorf("content[%d] tool result error mismatch: got %q, want %q", index, dec.Error.Error(), orig.Error.Error())
271		}
272
273	case ToolResultContentTypeMedia:
274		orig := original.(ToolResultOutputContentMedia)
275		dec := decoded.(ToolResultOutputContentMedia)
276		if orig.Data != dec.Data {
277			t.Errorf("content[%d] tool result media data mismatch", index)
278		}
279		if orig.MediaType != dec.MediaType {
280			t.Errorf("content[%d] tool result media type mismatch: got %q, want %q", index, dec.MediaType, orig.MediaType)
281		}
282	}
283}
284
285func TestHelperFunctions(t *testing.T) {
286	t.Run("NewUserMessage - text only", func(t *testing.T) {
287		msg := NewUserMessage("Hello")
288
289		data, err := json.Marshal(msg)
290		if err != nil {
291			t.Fatalf("failed to marshal: %v", err)
292		}
293
294		var decoded Message
295		if err := json.Unmarshal(data, &decoded); err != nil {
296			t.Fatalf("failed to unmarshal: %v", err)
297		}
298
299		if decoded.Role != MessageRoleUser {
300			t.Errorf("role mismatch: got %v, want %v", decoded.Role, MessageRoleUser)
301		}
302
303		if len(decoded.Content) != 1 {
304			t.Fatalf("expected 1 content part, got %d", len(decoded.Content))
305		}
306
307		textPart := decoded.Content[0].(TextPart)
308		if textPart.Text != "Hello" {
309			t.Errorf("text mismatch: got %q, want %q", textPart.Text, "Hello")
310		}
311	})
312
313	t.Run("NewUserMessage - with files", func(t *testing.T) {
314		msg := NewUserMessage("Check this image",
315			FilePart{
316				Filename:  "image1.jpg",
317				Data:      []byte{0xFF, 0xD8, 0xFF},
318				MediaType: "image/jpeg",
319			},
320			FilePart{
321				Filename:  "image2.png",
322				Data:      []byte{0x89, 0x50, 0x4E, 0x47},
323				MediaType: "image/png",
324			},
325		)
326
327		data, err := json.Marshal(msg)
328		if err != nil {
329			t.Fatalf("failed to marshal: %v", err)
330		}
331
332		var decoded Message
333		if err := json.Unmarshal(data, &decoded); err != nil {
334			t.Fatalf("failed to unmarshal: %v", err)
335		}
336
337		if len(decoded.Content) != 3 {
338			t.Fatalf("expected 3 content parts, got %d", len(decoded.Content))
339		}
340
341		// Check text part
342		textPart := decoded.Content[0].(TextPart)
343		if textPart.Text != "Check this image" {
344			t.Errorf("text mismatch: got %q, want %q", textPart.Text, "Check this image")
345		}
346
347		// Check first file
348		file1 := decoded.Content[1].(FilePart)
349		if file1.Filename != "image1.jpg" {
350			t.Errorf("file1 name mismatch: got %q, want %q", file1.Filename, "image1.jpg")
351		}
352
353		// Check second file
354		file2 := decoded.Content[2].(FilePart)
355		if file2.Filename != "image2.png" {
356			t.Errorf("file2 name mismatch: got %q, want %q", file2.Filename, "image2.png")
357		}
358	})
359
360	t.Run("NewSystemMessage - single prompt", func(t *testing.T) {
361		msg := NewSystemMessage("You are a helpful assistant.")
362
363		data, err := json.Marshal(msg)
364		if err != nil {
365			t.Fatalf("failed to marshal: %v", err)
366		}
367
368		var decoded Message
369		if err := json.Unmarshal(data, &decoded); err != nil {
370			t.Fatalf("failed to unmarshal: %v", err)
371		}
372
373		if decoded.Role != MessageRoleSystem {
374			t.Errorf("role mismatch: got %v, want %v", decoded.Role, MessageRoleSystem)
375		}
376
377		if len(decoded.Content) != 1 {
378			t.Fatalf("expected 1 content part, got %d", len(decoded.Content))
379		}
380
381		textPart := decoded.Content[0].(TextPart)
382		if textPart.Text != "You are a helpful assistant." {
383			t.Errorf("text mismatch: got %q, want %q", textPart.Text, "You are a helpful assistant.")
384		}
385	})
386
387	t.Run("NewSystemMessage - multiple prompts", func(t *testing.T) {
388		msg := NewSystemMessage("First instruction", "Second instruction", "Third instruction")
389
390		data, err := json.Marshal(msg)
391		if err != nil {
392			t.Fatalf("failed to marshal: %v", err)
393		}
394
395		var decoded Message
396		if err := json.Unmarshal(data, &decoded); err != nil {
397			t.Fatalf("failed to unmarshal: %v", err)
398		}
399
400		if len(decoded.Content) != 3 {
401			t.Fatalf("expected 3 content parts, got %d", len(decoded.Content))
402		}
403
404		expected := []string{"First instruction", "Second instruction", "Third instruction"}
405		for i, exp := range expected {
406			textPart := decoded.Content[i].(TextPart)
407			if textPart.Text != exp {
408				t.Errorf("content[%d] text mismatch: got %q, want %q", i, textPart.Text, exp)
409			}
410		}
411	})
412}
413
414func TestEdgeCases(t *testing.T) {
415	t.Run("empty text part", func(t *testing.T) {
416		msg := Message{
417			Role: MessageRoleUser,
418			Content: []MessagePart{
419				TextPart{Text: ""},
420			},
421		}
422
423		data, err := json.Marshal(msg)
424		if err != nil {
425			t.Fatalf("failed to marshal: %v", err)
426		}
427
428		var decoded Message
429		if err := json.Unmarshal(data, &decoded); err != nil {
430			t.Fatalf("failed to unmarshal: %v", err)
431		}
432
433		textPart := decoded.Content[0].(TextPart)
434		if textPart.Text != "" {
435			t.Errorf("expected empty text, got %q", textPart.Text)
436		}
437	})
438
439	t.Run("nil error in tool result", func(t *testing.T) {
440		msg := Message{
441			Role: MessageRoleTool,
442			Content: []MessagePart{
443				ToolResultPart{
444					ToolCallID: "call_123",
445					Output: ToolResultOutputContentError{
446						Error: nil,
447					},
448				},
449			},
450		}
451
452		data, err := json.Marshal(msg)
453		if err != nil {
454			t.Fatalf("failed to marshal: %v", err)
455		}
456
457		var decoded Message
458		if err := json.Unmarshal(data, &decoded); err != nil {
459			t.Fatalf("failed to unmarshal: %v", err)
460		}
461
462		toolResult := decoded.Content[0].(ToolResultPart)
463		errorOutput := toolResult.Output.(ToolResultOutputContentError)
464		if errorOutput.Error != nil {
465			t.Errorf("expected nil error, got %v", errorOutput.Error)
466		}
467	})
468
469	t.Run("empty file data", func(t *testing.T) {
470		msg := Message{
471			Role: MessageRoleUser,
472			Content: []MessagePart{
473				FilePart{
474					Filename:  "empty.txt",
475					Data:      []byte{},
476					MediaType: "text/plain",
477				},
478			},
479		}
480
481		data, err := json.Marshal(msg)
482		if err != nil {
483			t.Fatalf("failed to marshal: %v", err)
484		}
485
486		var decoded Message
487		if err := json.Unmarshal(data, &decoded); err != nil {
488			t.Fatalf("failed to unmarshal: %v", err)
489		}
490
491		filePart := decoded.Content[0].(FilePart)
492		if len(filePart.Data) != 0 {
493			t.Errorf("expected empty data, got %d bytes", len(filePart.Data))
494		}
495	})
496
497	t.Run("unicode in text", func(t *testing.T) {
498		msg := Message{
499			Role: MessageRoleUser,
500			Content: []MessagePart{
501				TextPart{Text: "Hello ไธ–็•Œ! ๐ŸŒ ะŸั€ะธะฒะตั‚"},
502			},
503		}
504
505		data, err := json.Marshal(msg)
506		if err != nil {
507			t.Fatalf("failed to marshal: %v", err)
508		}
509
510		var decoded Message
511		if err := json.Unmarshal(data, &decoded); err != nil {
512			t.Fatalf("failed to unmarshal: %v", err)
513		}
514
515		textPart := decoded.Content[0].(TextPart)
516		if textPart.Text != "Hello ไธ–็•Œ! ๐ŸŒ ะŸั€ะธะฒะตั‚" {
517			t.Errorf("unicode text mismatch: got %q, want %q", textPart.Text, "Hello ไธ–็•Œ! ๐ŸŒ ะŸั€ะธะฒะตั‚")
518		}
519	})
520}
521
522func TestInvalidJSONHandling(t *testing.T) {
523	t.Run("unknown message part type", func(t *testing.T) {
524		invalidJSON := `{
525			"role": "user",
526			"content": [
527				{
528					"type": "unknown-type",
529					"data": {}
530				}
531			],
532			"provider_options": null
533		}`
534
535		var msg Message
536		err := json.Unmarshal([]byte(invalidJSON), &msg)
537		if err == nil {
538			t.Error("expected error for unknown message part type, got nil")
539		}
540	})
541
542	t.Run("unknown tool result output type", func(t *testing.T) {
543		invalidJSON := `{
544			"role": "tool",
545			"content": [
546				{
547					"type": "tool-result",
548					"data": {
549						"tool_call_id": "call_123",
550						"output": {
551							"type": "unknown-output-type",
552							"data": {}
553						},
554						"provider_options": null
555					}
556				}
557			],
558			"provider_options": null
559		}`
560
561		var msg Message
562		err := json.Unmarshal([]byte(invalidJSON), &msg)
563		if err == nil {
564			t.Error("expected error for unknown tool result output type, got nil")
565		}
566	})
567
568	t.Run("malformed JSON", func(t *testing.T) {
569		invalidJSON := `{"role": "user", "content": [`
570
571		var msg Message
572		err := json.Unmarshal([]byte(invalidJSON), &msg)
573		if err == nil {
574			t.Error("expected error for malformed JSON, got nil")
575		}
576	})
577}
578
// mockProviderData is a minimal provider-options payload for tests. Its JSON
// form is {"type":"mock","key":...}.
type mockProviderData struct {
	Key string `json:"key"`
}

// Options has an empty body. NOTE(review): presumably it exists only to
// satisfy a provider-options interface — confirm against that interface.
func (m mockProviderData) Options() {}

// Type returns the discriminator written to the "type" JSON field.
func (m mockProviderData) Type() string { return "mock" }

// MarshalJSON encodes m as {"type":"mock","key":...}.
//
// The fields are embedded through the local type plain, which has the same
// fields but no methods. Embedding mockProviderData directly would promote
// this very MarshalJSON onto the wrapper struct. json.Marshal would then call
// it again and recurse until the stack overflows.
func (m mockProviderData) MarshalJSON() ([]byte, error) {
	type plain mockProviderData
	return json.Marshal(struct {
		Type string `json:"type"`
		plain
	}{
		Type:  m.Type(),
		plain: plain(m),
	})
}

// UnmarshalJSON decodes data into m. The "type" field is accepted but not
// validated. Like MarshalJSON, it embeds the method-free local type plain:
// embedding mockProviderData would promote this UnmarshalJSON onto *aux, and
// json.Unmarshal would re-enter it forever.
func (m *mockProviderData) UnmarshalJSON(data []byte) error {
	type plain mockProviderData
	var aux struct {
		Type string `json:"type"`
		plain
	}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}
	*m = mockProviderData(aux.plain)
	return nil
}
607
608func TestPromptSerialization(t *testing.T) {
609	t.Run("serialize prompt (message slice)", func(t *testing.T) {
610		prompt := Prompt{
611			NewSystemMessage("You are helpful"),
612			NewUserMessage("Hello"),
613			Message{
614				Role: MessageRoleAssistant,
615				Content: []MessagePart{
616					TextPart{Text: "Hi there!"},
617				},
618			},
619		}
620
621		data, err := json.Marshal(prompt)
622		if err != nil {
623			t.Fatalf("failed to marshal prompt: %v", err)
624		}
625
626		var decoded Prompt
627		if err := json.Unmarshal(data, &decoded); err != nil {
628			t.Fatalf("failed to unmarshal prompt: %v", err)
629		}
630
631		if len(decoded) != 3 {
632			t.Fatalf("expected 3 messages, got %d", len(decoded))
633		}
634
635		if decoded[0].Role != MessageRoleSystem {
636			t.Errorf("message 0 role mismatch: got %v, want %v", decoded[0].Role, MessageRoleSystem)
637		}
638
639		if decoded[1].Role != MessageRoleUser {
640			t.Errorf("message 1 role mismatch: got %v, want %v", decoded[1].Role, MessageRoleUser)
641		}
642
643		if decoded[2].Role != MessageRoleAssistant {
644			t.Errorf("message 2 role mismatch: got %v, want %v", decoded[2].Role, MessageRoleAssistant)
645		}
646	})
647}