messages_test.go

  1package db
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"strings"
  8	"testing"
  9	"time"
 10
 11	"shelley.exe.dev/db/generated"
 12)
 13
 14func TestMessageService_Create(t *testing.T) {
 15	db := setupTestDB(t)
 16	defer db.Close()
 17
 18	// Using db directly instead of service
 19	// Using db directly instead of service
 20	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 21	defer cancel()
 22
 23	// Create a test conversation
 24	conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
 25	if err != nil {
 26		t.Fatalf("Failed to create test conversation: %v", err)
 27	}
 28
 29	tests := []struct {
 30		name      string
 31		msgType   MessageType
 32		llmData   interface{}
 33		userData  interface{}
 34		usageData interface{}
 35	}{
 36		{
 37			name:      "user message with data",
 38			msgType:   MessageTypeUser,
 39			llmData:   map[string]string{"content": "Hello, AI!"},
 40			userData:  map[string]string{"display": "Hello, AI!"},
 41			usageData: nil,
 42		},
 43		{
 44			name:      "agent message with usage",
 45			msgType:   MessageTypeAgent,
 46			llmData:   map[string]string{"response": "Hello, human!"},
 47			userData:  map[string]string{"formatted": "Hello, human!"},
 48			usageData: map[string]int{"tokens": 42},
 49		},
 50		{
 51			name:      "tool message minimal",
 52			msgType:   MessageTypeTool,
 53			llmData:   nil,
 54			userData:  nil,
 55			usageData: nil,
 56		},
 57	}
 58
 59	for _, tt := range tests {
 60		t.Run(tt.name, func(t *testing.T) {
 61			msg, err := db.CreateMessage(ctx, CreateMessageParams{
 62				ConversationID: conv.ConversationID,
 63				Type:           tt.msgType,
 64				LLMData:        tt.llmData,
 65				UserData:       tt.userData,
 66				UsageData:      tt.usageData,
 67			})
 68			if err != nil {
 69				t.Errorf("Create() error = %v", err)
 70				return
 71			}
 72
 73			if msg.MessageID == "" {
 74				t.Error("Expected non-empty message ID")
 75			}
 76
 77			if msg.ConversationID != conv.ConversationID {
 78				t.Errorf("Expected conversation ID %s, got %s", conv.ConversationID, msg.ConversationID)
 79			}
 80
 81			if msg.Type != string(tt.msgType) {
 82				t.Errorf("Expected message type %s, got %s", tt.msgType, msg.Type)
 83			}
 84
 85			// Test JSON data marshalling
 86			if tt.llmData != nil {
 87				if msg.LlmData == nil {
 88					t.Error("Expected LLM data to be non-nil")
 89				} else {
 90					var unmarshalled map[string]interface{}
 91					err := json.Unmarshal([]byte(*msg.LlmData), &unmarshalled)
 92					if err != nil {
 93						t.Errorf("Failed to unmarshal LLM data: %v", err)
 94					}
 95				}
 96			} else {
 97				if msg.LlmData != nil {
 98					t.Error("Expected LLM data to be nil")
 99				}
100			}
101
102			if msg.CreatedAt.IsZero() {
103				t.Error("Expected non-zero created_at time")
104			}
105		})
106	}
107}
108
109func TestMessageService_GetByID(t *testing.T) {
110	db := setupTestDB(t)
111	defer db.Close()
112
113	// Using db directly instead of service
114	// Using db directly instead of service
115	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
116	defer cancel()
117
118	// Create a test conversation
119	conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
120	if err != nil {
121		t.Fatalf("Failed to create test conversation: %v", err)
122	}
123
124	// Create a test message
125	created, err := db.CreateMessage(ctx, CreateMessageParams{
126		ConversationID: conv.ConversationID,
127		Type:           MessageTypeUser,
128		LLMData:        map[string]string{"content": "test message"},
129	})
130	if err != nil {
131		t.Fatalf("Failed to create test message: %v", err)
132	}
133
134	// Test getting existing message
135	msg, err := db.GetMessageByID(ctx, created.MessageID)
136	if err != nil {
137		t.Errorf("GetByID() error = %v", err)
138		return
139	}
140
141	if msg.MessageID != created.MessageID {
142		t.Errorf("Expected message ID %s, got %s", created.MessageID, msg.MessageID)
143	}
144
145	// Test getting non-existent message
146	_, err = db.GetMessageByID(ctx, "non-existent")
147	if err == nil {
148		t.Error("Expected error for non-existent message")
149	}
150	if !strings.Contains(err.Error(), "not found") {
151		t.Errorf("Expected 'not found' in error message, got: %v", err)
152	}
153}
154
155func TestMessageService_ListByConversation(t *testing.T) {
156	db := setupTestDB(t)
157	defer db.Close()
158
159	// Using db directly instead of service
160	// Using db directly instead of service
161	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
162	defer cancel()
163
164	// Create a test conversation
165	conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
166	if err != nil {
167		t.Fatalf("Failed to create test conversation: %v", err)
168	}
169
170	// Create multiple test messages
171	msgTypes := []MessageType{MessageTypeUser, MessageTypeAgent, MessageTypeTool}
172	for i, msgType := range msgTypes {
173		_, err := db.CreateMessage(ctx, CreateMessageParams{
174			ConversationID: conv.ConversationID,
175			Type:           msgType,
176			LLMData:        map[string]interface{}{"index": i, "type": string(msgType)},
177		})
178		if err != nil {
179			t.Fatalf("Failed to create test message %d: %v", i, err)
180		}
181	}
182
183	// List messages
184	var messages []generated.Message
185	err = db.Queries(ctx, func(q *generated.Queries) error {
186		var err error
187		messages, err = q.ListMessages(ctx, conv.ConversationID)
188		return err
189	})
190	if err != nil {
191		t.Errorf("ListByConversation() error = %v", err)
192		return
193	}
194
195	if len(messages) != 3 {
196		t.Errorf("Expected 3 messages, got %d", len(messages))
197	}
198
199	// Messages should be ordered by created_at ASC (oldest first) by the query
200	// We verify this by checking the message types are in the order we created them
201	expectedTypes := []string{"user", "agent", "tool"}
202	for i, msg := range messages {
203		if msg.Type != expectedTypes[i] {
204			t.Errorf("Expected message %d to be type %s, got %s", i, expectedTypes[i], msg.Type)
205		}
206	}
207}
208
209func TestMessageService_ListByType(t *testing.T) {
210	db := setupTestDB(t)
211	defer db.Close()
212
213	// Using db directly instead of service
214	// Using db directly instead of service
215	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
216	defer cancel()
217
218	// Create a test conversation
219	conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
220	if err != nil {
221		t.Fatalf("Failed to create test conversation: %v", err)
222	}
223
224	// Create messages of different types
225	msgTypes := []MessageType{MessageTypeUser, MessageTypeAgent, MessageTypeUser, MessageTypeTool}
226	for i, msgType := range msgTypes {
227		_, err := db.CreateMessage(ctx, CreateMessageParams{
228			ConversationID: conv.ConversationID,
229			Type:           msgType,
230			LLMData:        map[string]interface{}{"index": i},
231		})
232		if err != nil {
233			t.Fatalf("Failed to create test message %d: %v", i, err)
234		}
235	}
236
237	// List only user messages
238	userMessages, err := db.ListMessagesByType(ctx, conv.ConversationID, MessageTypeUser)
239	if err != nil {
240		t.Errorf("ListByType() error = %v", err)
241		return
242	}
243
244	if len(userMessages) != 2 {
245		t.Errorf("Expected 2 user messages, got %d", len(userMessages))
246	}
247
248	// Verify all messages are user type
249	for _, msg := range userMessages {
250		if msg.Type != string(MessageTypeUser) {
251			t.Errorf("Expected user message, got %s", msg.Type)
252		}
253	}
254}
255
256func TestMessageService_GetLatest(t *testing.T) {
257	db := setupTestDB(t)
258	defer db.Close()
259
260	// Using db directly instead of service
261	// Using db directly instead of service
262	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
263	defer cancel()
264
265	// Create a test conversation
266	conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
267	if err != nil {
268		t.Fatalf("Failed to create test conversation: %v", err)
269	}
270
271	// Test getting latest from empty conversation
272	_, err = db.GetLatestMessage(ctx, conv.ConversationID)
273	if err == nil {
274		t.Error("Expected error for conversation with no messages")
275	}
276
277	// Create multiple test messages
278	var lastCreated *generated.Message
279	for i := 0; i < 3; i++ {
280		created, err := db.CreateMessage(ctx, CreateMessageParams{
281			ConversationID: conv.ConversationID,
282			Type:           MessageTypeUser,
283			LLMData:        map[string]interface{}{"index": i},
284		})
285		if err != nil {
286			t.Fatalf("Failed to create test message %d: %v", i, err)
287		}
288		lastCreated = created
289	}
290
291	// Get the latest message
292	latest, err := db.GetLatestMessage(ctx, conv.ConversationID)
293	if err != nil {
294		t.Errorf("GetLatest() error = %v", err)
295		return
296	}
297
298	if latest.MessageID != lastCreated.MessageID {
299		t.Errorf("Expected latest message ID %s, got %s", lastCreated.MessageID, latest.MessageID)
300	}
301}
302
303func TestMessageService_Delete(t *testing.T) {
304	db := setupTestDB(t)
305	defer db.Close()
306
307	// Using db directly instead of service
308	// Using db directly instead of service
309	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
310	defer cancel()
311
312	// Create a test conversation
313	conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
314	if err != nil {
315		t.Fatalf("Failed to create test conversation: %v", err)
316	}
317
318	// Create a test message
319	created, err := db.CreateMessage(ctx, CreateMessageParams{
320		ConversationID: conv.ConversationID,
321		Type:           MessageTypeUser,
322		LLMData:        map[string]string{"content": "test message"},
323	})
324	if err != nil {
325		t.Fatalf("Failed to create test message: %v", err)
326	}
327
328	// Delete the message
329	err = db.QueriesTx(ctx, func(q *generated.Queries) error {
330		return q.DeleteMessage(ctx, created.MessageID)
331	})
332	if err != nil {
333		t.Errorf("Delete() error = %v", err)
334		return
335	}
336
337	// Verify it's gone
338	_, err = db.GetMessageByID(ctx, created.MessageID)
339	if err == nil {
340		t.Error("Expected error when getting deleted message")
341	}
342}
343
344func TestMessageService_CountInConversation(t *testing.T) {
345	db := setupTestDB(t)
346	defer db.Close()
347
348	// Using db directly instead of service
349	// Using db directly instead of service
350	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
351	defer cancel()
352
353	// Create a test conversation
354	conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
355	if err != nil {
356		t.Fatalf("Failed to create test conversation: %v", err)
357	}
358
359	// Initial count should be 0
360	var count int64
361	err = db.Queries(ctx, func(q *generated.Queries) error {
362		var err error
363		count, err = q.CountMessagesInConversation(ctx, conv.ConversationID)
364		return err
365	})
366	if err != nil {
367		t.Errorf("CountInConversation() error = %v", err)
368		return
369	}
370	if count != 0 {
371		t.Errorf("Expected initial count 0, got %d", count)
372	}
373
374	// Create test messages
375	for i := 0; i < 4; i++ {
376		_, err := db.CreateMessage(ctx, CreateMessageParams{
377			ConversationID: conv.ConversationID,
378			Type:           MessageTypeUser,
379			LLMData:        map[string]interface{}{"index": i},
380		})
381		if err != nil {
382			t.Fatalf("Failed to create test message %d: %v", i, err)
383		}
384	}
385
386	// Count should now be 4
387	err = db.Queries(ctx, func(q *generated.Queries) error {
388		var err error
389		count, err = q.CountMessagesInConversation(ctx, conv.ConversationID)
390		return err
391	})
392	if err != nil {
393		t.Errorf("CountInConversation() error = %v", err)
394		return
395	}
396	if count != 4 {
397		t.Errorf("Expected count 4, got %d", count)
398	}
399}
400
401func TestMessageService_CountByType(t *testing.T) {
402	db := setupTestDB(t)
403	defer db.Close()
404
405	// Using db directly instead of service
406	// Using db directly instead of service
407	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
408	defer cancel()
409
410	// Create a test conversation
411	conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
412	if err != nil {
413		t.Fatalf("Failed to create test conversation: %v", err)
414	}
415
416	// Create messages of different types
417	msgTypes := []MessageType{MessageTypeUser, MessageTypeAgent, MessageTypeUser, MessageTypeTool, MessageTypeUser}
418	for i, msgType := range msgTypes {
419		_, err := db.CreateMessage(ctx, CreateMessageParams{
420			ConversationID: conv.ConversationID,
421			Type:           msgType,
422			LLMData:        map[string]interface{}{"index": i},
423		})
424		if err != nil {
425			t.Fatalf("Failed to create test message %d: %v", i, err)
426		}
427	}
428
429	// Count user messages (should be 3)
430	userCount, err := db.CountMessagesByType(ctx, conv.ConversationID, MessageTypeUser)
431	if err != nil {
432		t.Errorf("CountByType() error = %v", err)
433		return
434	}
435	if userCount != 3 {
436		t.Errorf("Expected 3 user messages, got %d", userCount)
437	}
438
439	// Count agent messages (should be 1)
440	agentCount, err := db.CountMessagesByType(ctx, conv.ConversationID, MessageTypeAgent)
441	if err != nil {
442		t.Errorf("CountByType() error = %v", err)
443		return
444	}
445	if agentCount != 1 {
446		t.Errorf("Expected 1 agent message, got %d", agentCount)
447	}
448
449	// Count tool messages (should be 1)
450	toolCount, err := db.CountMessagesByType(ctx, conv.ConversationID, MessageTypeTool)
451	if err != nil {
452		t.Errorf("CountByType() error = %v", err)
453		return
454	}
455	if toolCount != 1 {
456		t.Errorf("Expected 1 tool message, got %d", toolCount)
457	}
458}
459
460func TestMessageService_ListMessagesByConversationPaginated(t *testing.T) {
461	db := setupTestDB(t)
462	defer db.Close()
463
464	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
465	defer cancel()
466
467	// Create a test conversation
468	conv, err := db.CreateConversation(ctx, stringPtr("test-conversation-paginated"), true, nil, nil)
469	if err != nil {
470		t.Fatalf("Failed to create test conversation: %v", err)
471	}
472
473	// Create multiple test messages
474	for i := 0; i < 5; i++ {
475		_, err := db.CreateMessage(ctx, CreateMessageParams{
476			ConversationID: conv.ConversationID,
477			Type:           MessageTypeUser,
478			LLMData:        map[string]string{"text": fmt.Sprintf("test message %d", i)},
479		})
480		if err != nil {
481			t.Fatalf("Failed to create test message %d: %v", i, err)
482		}
483	}
484
485	// Test ListMessagesByConversationPaginated with limit and offset
486	messages, err := db.ListMessagesByConversationPaginated(ctx, conv.ConversationID, 3, 0)
487	if err != nil {
488		t.Errorf("ListMessagesByConversationPaginated() error = %v", err)
489	}
490
491	if len(messages) != 3 {
492		t.Errorf("Expected 3 messages, got %d", len(messages))
493	}
494
495	// Test with offset
496	messages2, err := db.ListMessagesByConversationPaginated(ctx, conv.ConversationID, 3, 3)
497	if err != nil {
498		t.Errorf("ListMessagesByConversationPaginated() with offset error = %v", err)
499	}
500
501	if len(messages2) != 2 {
502		t.Errorf("Expected 2 messages with offset, got %d", len(messages2))
503	}
504
505	// Verify no duplicate messages between pages
506	messageIDs := make(map[string]bool)
507	for _, msg := range messages {
508		if messageIDs[msg.MessageID] {
509			t.Error("Found duplicate message ID in first page")
510		}
511		messageIDs[msg.MessageID] = true
512	}
513
514	for _, msg := range messages2 {
515		if messageIDs[msg.MessageID] {
516			t.Error("Found duplicate message ID in second page")
517		}
518		messageIDs[msg.MessageID] = true
519	}
520}