1package db
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "strings"
8 "testing"
9 "time"
10
11 "shelley.exe.dev/db/generated"
12)
13
// TestMessageService_Create checks that CreateMessage persists a message of
// each type, assigns it an ID and a creation time, and stores LLM data as
// decodable JSON when supplied (leaving it nil when it is not).
func TestMessageService_Create(t *testing.T) {
	store := setupTestDB(t)
	defer store.Close()

	// The database is exercised directly rather than through a service layer.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conv, err := store.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
	if err != nil {
		t.Fatalf("Failed to create test conversation: %v", err)
	}

	cases := []struct {
		name      string
		msgType   MessageType
		llmData   interface{}
		userData  interface{}
		usageData interface{}
	}{
		{
			"user message with data",
			MessageTypeUser,
			map[string]string{"content": "Hello, AI!"},
			map[string]string{"display": "Hello, AI!"},
			nil,
		},
		{
			"agent message with usage",
			MessageTypeAgent,
			map[string]string{"response": "Hello, human!"},
			map[string]string{"formatted": "Hello, human!"},
			map[string]int{"tokens": 42},
		},
		{
			"tool message minimal",
			MessageTypeTool,
			nil,
			nil,
			nil,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			msg, err := store.CreateMessage(ctx, CreateMessageParams{
				ConversationID: conv.ConversationID,
				Type:           tc.msgType,
				LLMData:        tc.llmData,
				UserData:       tc.userData,
				UsageData:      tc.usageData,
			})
			if err != nil {
				t.Errorf("Create() error = %v", err)
				return
			}

			if msg.MessageID == "" {
				t.Error("Expected non-empty message ID")
			}
			if msg.ConversationID != conv.ConversationID {
				t.Errorf("Expected conversation ID %s, got %s", conv.ConversationID, msg.ConversationID)
			}
			if got := msg.Type; got != string(tc.msgType) {
				t.Errorf("Expected message type %s, got %s", tc.msgType, got)
			}

			// LLM data must be present and decodable exactly when the test
			// case supplied some.
			switch {
			case tc.llmData == nil && msg.LlmData != nil:
				t.Error("Expected LLM data to be nil")
			case tc.llmData != nil && msg.LlmData == nil:
				t.Error("Expected LLM data to be non-nil")
			case tc.llmData != nil:
				var decoded map[string]interface{}
				if err := json.Unmarshal([]byte(*msg.LlmData), &decoded); err != nil {
					t.Errorf("Failed to unmarshal LLM data: %v", err)
				}
			}

			if msg.CreatedAt.IsZero() {
				t.Error("Expected non-zero created_at time")
			}
		})
	}
}
108
// TestMessageService_GetByID checks that GetMessageByID returns a message that
// was just created, and that looking up an unknown ID fails with an error
// mentioning "not found".
func TestMessageService_GetByID(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()

	// The database is used directly rather than through a service layer.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
	if err != nil {
		t.Fatalf("Failed to create test conversation: %v", err)
	}

	created, err := db.CreateMessage(ctx, CreateMessageParams{
		ConversationID: conv.ConversationID,
		Type:           MessageTypeUser,
		LLMData:        map[string]string{"content": "test message"},
	})
	if err != nil {
		t.Fatalf("Failed to create test message: %v", err)
	}

	// Existing message: the lookup must succeed and return the same ID.
	msg, err := db.GetMessageByID(ctx, created.MessageID)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if msg.MessageID != created.MessageID {
		t.Errorf("Expected message ID %s, got %s", created.MessageID, msg.MessageID)
	}

	// Unknown ID: the lookup must fail. Stop immediately on a nil error,
	// because the check below calls err.Error(), which would panic on a nil
	// interface and hide the real failure.
	_, err = db.GetMessageByID(ctx, "non-existent")
	if err == nil {
		t.Fatal("Expected error for non-existent message")
	}
	if !strings.Contains(err.Error(), "not found") {
		t.Errorf("Expected 'not found' in error message, got: %v", err)
	}
}
154
// TestMessageService_ListByConversation checks that ListMessages returns every
// message of a conversation, in the order the messages were created.
func TestMessageService_ListByConversation(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()

	// The database is used directly rather than through a service layer.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
	if err != nil {
		t.Fatalf("Failed to create test conversation: %v", err)
	}

	// One message per type; this slice doubles as the expected listing order.
	msgTypes := []MessageType{MessageTypeUser, MessageTypeAgent, MessageTypeTool}
	for i, msgType := range msgTypes {
		_, err := db.CreateMessage(ctx, CreateMessageParams{
			ConversationID: conv.ConversationID,
			Type:           msgType,
			LLMData:        map[string]interface{}{"index": i, "type": string(msgType)},
		})
		if err != nil {
			t.Fatalf("Failed to create test message %d: %v", i, err)
		}
	}

	var messages []generated.Message
	err = db.Queries(ctx, func(q *generated.Queries) error {
		var err error
		messages, err = q.ListMessages(ctx, conv.ConversationID)
		return err
	})
	if err != nil {
		t.Fatalf("ListByConversation() error = %v", err)
	}

	// Stop on a count mismatch: the ordering check below indexes msgTypes by
	// position and would panic if more messages came back than were created.
	if len(messages) != len(msgTypes) {
		t.Fatalf("Expected %d messages, got %d", len(msgTypes), len(messages))
	}

	// The query is expected to order by created_at ascending (oldest first),
	// so the types must come back in creation order.
	for i, msg := range messages {
		if want := string(msgTypes[i]); msg.Type != want {
			t.Errorf("Expected message %d to be type %s, got %s", i, want, msg.Type)
		}
	}
}
208
// TestMessageService_ListByType checks that ListMessagesByType returns only
// the messages of the requested type within a conversation.
func TestMessageService_ListByType(t *testing.T) {
	store := setupTestDB(t)
	defer store.Close()

	// Talks to the database directly; no service layer is involved.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conv, err := store.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
	if err != nil {
		t.Fatalf("Failed to create test conversation: %v", err)
	}

	// Two of the four messages are user messages.
	seeded := []MessageType{MessageTypeUser, MessageTypeAgent, MessageTypeUser, MessageTypeTool}
	for idx := range seeded {
		if _, createErr := store.CreateMessage(ctx, CreateMessageParams{
			ConversationID: conv.ConversationID,
			Type:           seeded[idx],
			LLMData:        map[string]interface{}{"index": idx},
		}); createErr != nil {
			t.Fatalf("Failed to create test message %d: %v", idx, createErr)
		}
	}

	userMessages, err := store.ListMessagesByType(ctx, conv.ConversationID, MessageTypeUser)
	if err != nil {
		t.Errorf("ListByType() error = %v", err)
		return
	}

	if n := len(userMessages); n != 2 {
		t.Errorf("Expected 2 user messages, got %d", n)
	}

	wantType := string(MessageTypeUser)
	for _, m := range userMessages {
		if m.Type == wantType {
			continue
		}
		t.Errorf("Expected user message, got %s", m.Type)
	}
}
255
// TestMessageService_GetLatest checks that GetLatestMessage fails for a
// conversation with no messages and, once messages exist, returns the one
// created last.
func TestMessageService_GetLatest(t *testing.T) {
	store := setupTestDB(t)
	defer store.Close()

	// Talks to the database directly; no service layer is involved.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conv, err := store.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
	if err != nil {
		t.Fatalf("Failed to create test conversation: %v", err)
	}

	// Nothing has been written yet, so there is no latest message to return.
	if _, err = store.GetLatestMessage(ctx, conv.ConversationID); err == nil {
		t.Error("Expected error for conversation with no messages")
	}

	// Insert three messages, remembering whichever was created last.
	var newest *generated.Message
	for idx := 0; idx < 3; idx++ {
		msg, createErr := store.CreateMessage(ctx, CreateMessageParams{
			ConversationID: conv.ConversationID,
			Type:           MessageTypeUser,
			LLMData:        map[string]interface{}{"index": idx},
		})
		if createErr != nil {
			t.Fatalf("Failed to create test message %d: %v", idx, createErr)
		}
		newest = msg
	}

	latest, err := store.GetLatestMessage(ctx, conv.ConversationID)
	if err != nil {
		t.Errorf("GetLatest() error = %v", err)
		return
	}

	if got, want := latest.MessageID, newest.MessageID; got != want {
		t.Errorf("Expected latest message ID %s, got %s", want, got)
	}
}
302
// TestMessageService_Delete checks that a message removed with DeleteMessage
// (run through QueriesTx) can no longer be fetched by its ID.
func TestMessageService_Delete(t *testing.T) {
	store := setupTestDB(t)
	defer store.Close()

	// Talks to the database directly; no service layer is involved.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conv, err := store.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
	if err != nil {
		t.Fatalf("Failed to create test conversation: %v", err)
	}

	victim, err := store.CreateMessage(ctx, CreateMessageParams{
		ConversationID: conv.ConversationID,
		Type:           MessageTypeUser,
		LLMData:        map[string]string{"content": "test message"},
	})
	if err != nil {
		t.Fatalf("Failed to create test message: %v", err)
	}

	if err := store.QueriesTx(ctx, func(q *generated.Queries) error {
		return q.DeleteMessage(ctx, victim.MessageID)
	}); err != nil {
		t.Errorf("Delete() error = %v", err)
		return
	}

	// A lookup after deletion must fail.
	if _, err := store.GetMessageByID(ctx, victim.MessageID); err == nil {
		t.Error("Expected error when getting deleted message")
	}
}
343
// TestMessageService_CountInConversation checks that
// CountMessagesInConversation reports 0 for a new conversation and tracks
// messages as they are added.
func TestMessageService_CountInConversation(t *testing.T) {
	store := setupTestDB(t)
	defer store.Close()

	// Talks to the database directly; no service layer is involved.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conv, err := store.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
	if err != nil {
		t.Fatalf("Failed to create test conversation: %v", err)
	}

	// countMessages runs the count query for this test's conversation.
	countMessages := func() (int64, error) {
		var total int64
		queryErr := store.Queries(ctx, func(q *generated.Queries) error {
			var qerr error
			total, qerr = q.CountMessagesInConversation(ctx, conv.ConversationID)
			return qerr
		})
		return total, queryErr
	}

	count, err := countMessages()
	if err != nil {
		t.Errorf("CountInConversation() error = %v", err)
		return
	}
	if count != 0 {
		t.Errorf("Expected initial count 0, got %d", count)
	}

	for idx := 0; idx < 4; idx++ {
		if _, createErr := store.CreateMessage(ctx, CreateMessageParams{
			ConversationID: conv.ConversationID,
			Type:           MessageTypeUser,
			LLMData:        map[string]interface{}{"index": idx},
		}); createErr != nil {
			t.Fatalf("Failed to create test message %d: %v", idx, createErr)
		}
	}

	if count, err = countMessages(); err != nil {
		t.Errorf("CountInConversation() error = %v", err)
		return
	}
	if count != 4 {
		t.Errorf("Expected count 4, got %d", count)
	}
}
400
// TestMessageService_CountByType checks that CountMessagesByType counts only
// the messages of the requested type within a conversation.
func TestMessageService_CountByType(t *testing.T) {
	store := setupTestDB(t)
	defer store.Close()

	// Talks to the database directly; no service layer is involved.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conv, err := store.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
	if err != nil {
		t.Fatalf("Failed to create test conversation: %v", err)
	}

	// Three user messages, one agent message, one tool message.
	seeded := []MessageType{MessageTypeUser, MessageTypeAgent, MessageTypeUser, MessageTypeTool, MessageTypeUser}
	for idx := range seeded {
		if _, createErr := store.CreateMessage(ctx, CreateMessageParams{
			ConversationID: conv.ConversationID,
			Type:           seeded[idx],
			LLMData:        map[string]interface{}{"index": idx},
		}); createErr != nil {
			t.Fatalf("Failed to create test message %d: %v", idx, createErr)
		}
	}

	checks := []struct {
		msgType MessageType
		want    int64
		noun    string
	}{
		{MessageTypeUser, 3, "user messages"},
		{MessageTypeAgent, 1, "agent message"},
		{MessageTypeTool, 1, "tool message"},
	}
	for _, c := range checks {
		got, countErr := store.CountMessagesByType(ctx, conv.ConversationID, c.msgType)
		if countErr != nil {
			t.Errorf("CountByType() error = %v", countErr)
			return
		}
		if int64(got) != c.want {
			t.Errorf("Expected %d %s, got %d", c.want, c.noun, got)
		}
	}
}
459
// TestMessageService_ListMessagesByConversationPaginated checks that
// ListMessagesByConversationPaginated honours limit and offset, and that two
// consecutive pages share no message IDs.
func TestMessageService_ListMessagesByConversationPaginated(t *testing.T) {
	store := setupTestDB(t)
	defer store.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conv, err := store.CreateConversation(ctx, stringPtr("test-conversation-paginated"), true, nil, nil)
	if err != nil {
		t.Fatalf("Failed to create test conversation: %v", err)
	}

	// Five messages: a page of 3 followed by a page of 2.
	for idx := 0; idx < 5; idx++ {
		if _, createErr := store.CreateMessage(ctx, CreateMessageParams{
			ConversationID: conv.ConversationID,
			Type:           MessageTypeUser,
			LLMData:        map[string]string{"text": fmt.Sprintf("test message %d", idx)},
		}); createErr != nil {
			t.Fatalf("Failed to create test message %d: %v", idx, createErr)
		}
	}

	firstPage, err := store.ListMessagesByConversationPaginated(ctx, conv.ConversationID, 3, 0)
	if err != nil {
		t.Errorf("ListMessagesByConversationPaginated() error = %v", err)
	}
	if n := len(firstPage); n != 3 {
		t.Errorf("Expected 3 messages, got %d", n)
	}

	secondPage, err := store.ListMessagesByConversationPaginated(ctx, conv.ConversationID, 3, 3)
	if err != nil {
		t.Errorf("ListMessagesByConversationPaginated() with offset error = %v", err)
	}
	if n := len(secondPage); n != 2 {
		t.Errorf("Expected 2 messages with offset, got %d", n)
	}

	// Track every ID across both pages; any repeat is reported with the page
	// it was found on.
	seen := make(map[string]struct{})
	recordID := func(id, dupMsg string) {
		if _, dup := seen[id]; dup {
			t.Error(dupMsg)
		}
		seen[id] = struct{}{}
	}
	for _, m := range firstPage {
		recordID(m.MessageID, "Found duplicate message ID in first page")
	}
	for _, m := range secondPage {
		recordID(m.MessageID, "Found duplicate message ID in second page")
	}
}