diff --git a/db/db.go b/db/db.go index c922094d2dae3563c97ef358a6e018182c82fd43..3b4543d62ae4639d5912eca96fcf213ea57937b1 100644 --- a/db/db.go +++ b/db/db.go @@ -391,12 +391,13 @@ const ( // CreateMessageParams contains parameters for creating a message type CreateMessageParams struct { - ConversationID string - Type MessageType - LLMData interface{} // Will be JSON marshalled - UserData interface{} // Will be JSON marshalled - UsageData interface{} // Will be JSON marshalled - DisplayData interface{} // Will be JSON marshalled, tool-specific display content + ConversationID string + Type MessageType + LLMData interface{} // Will be JSON marshalled + UserData interface{} // Will be JSON marshalled + UsageData interface{} // Will be JSON marshalled + DisplayData interface{} // Will be JSON marshalled, tool-specific display content + ExcludedFromContext bool // If true, message is stored but not sent to LLM } // CreateMessage creates a new message @@ -453,14 +454,15 @@ func (db *DB) CreateMessage(ctx context.Context, params CreateMessageParams) (*g } message, err = q.CreateMessage(ctx, generated.CreateMessageParams{ - MessageID: messageID, - ConversationID: params.ConversationID, - SequenceID: sequenceID, - Type: string(params.Type), - LlmData: llmDataJSON, - UserData: userDataJSON, - UsageData: usageDataJSON, - DisplayData: displayDataJSON, + MessageID: messageID, + ConversationID: params.ConversationID, + SequenceID: sequenceID, + Type: string(params.Type), + LlmData: llmDataJSON, + UserData: userDataJSON, + UsageData: usageDataJSON, + DisplayData: displayDataJSON, + ExcludedFromContext: params.ExcludedFromContext, }) return err }) @@ -498,6 +500,18 @@ func (db *DB) ListMessagesByConversationPaginated(ctx context.Context, conversat return messages, err } +// ListMessagesForContext retrieves messages that should be sent to the LLM (excludes excluded_from_context=true) +func (db *DB) ListMessagesForContext(ctx context.Context, conversationID string) ([]generated.Message, error) { + var messages []generated.Message + err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error { + q := generated.New(rx.Conn()) + var err error + messages, err = q.ListMessagesForContext(ctx, conversationID) + return err + }) + return messages, err +} + // ListMessagesByType retrieves messages of a specific type in a conversation func (db *DB) ListMessagesByType(ctx context.Context, conversationID string, messageType MessageType) ([]generated.Message, error) { var messages []generated.Message diff --git a/db/generated/messages.sql.go b/db/generated/messages.sql.go index f5f5f13b00933e55b18af3da2b668209f34f4ed6..75935e3470286193c94ae447bc03fafb53cf0b49 100644 --- a/db/generated/messages.sql.go +++ b/db/generated/messages.sql.go @@ -39,20 +39,21 @@ func (q *Queries) CountMessagesInConversation(ctx context.Context, conversationI } const createMessage = `-- name: CreateMessage :one -INSERT INTO messages (message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, display_data) -VALUES (?, ?, ?, ?, ?, ?, ?, ?) -RETURNING message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data +INSERT INTO messages (message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, display_data, excluded_from_context) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) +RETURNING message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context ` type CreateMessageParams struct { - MessageID string `json:"message_id"` - ConversationID string `json:"conversation_id"` - SequenceID int64 `json:"sequence_id"` - Type string `json:"type"` - LlmData *string `json:"llm_data"` - UserData *string `json:"user_data"` - UsageData *string `json:"usage_data"` - DisplayData *string `json:"display_data"` + MessageID string `json:"message_id"` + ConversationID string `json:"conversation_id"` + SequenceID int64 `json:"sequence_id"` + Type string `json:"type"` + LlmData *string `json:"llm_data"` + UserData *string `json:"user_data"` + UsageData *string `json:"usage_data"` + DisplayData *string `json:"display_data"` + ExcludedFromContext bool `json:"excluded_from_context"` } func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) { @@ -65,6 +66,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M arg.UserData, arg.UsageData, arg.DisplayData, + arg.ExcludedFromContext, ) var i Message err := row.Scan( @@ -77,6 +79,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, ) return i, err } @@ -102,7 +105,7 @@ func (q *Queries) DeleteMessage(ctx context.Context, messageID string) error { } const getLatestMessage = `-- name: GetLatestMessage :one -SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data FROM messages +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages WHERE conversation_id = ? ORDER BY sequence_id DESC LIMIT 1 @@ -121,12 +124,13 @@ func (q *Queries) GetLatestMessage(ctx context.Context, conversationID string) ( &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, ) return i, err } const getMessage = `-- name: GetMessage :one -SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data FROM messages +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages WHERE message_id = ? ` @@ -143,6 +147,7 @@ func (q *Queries) GetMessage(ctx context.Context, messageID string) (Message, er &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, ) return i, err } @@ -161,7 +166,7 @@ func (q *Queries) GetNextSequenceID(ctx context.Context, conversationID string) } const listMessages = `-- name: ListMessages :many -SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data FROM messages +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages WHERE conversation_id = ? ORDER BY sequence_id ASC ` @@ -185,6 +190,7 @@ func (q *Queries) ListMessages(ctx context.Context, conversationID string) ([]Me &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, ); err != nil { return nil, err } @@ -200,7 +206,7 @@ func (q *Queries) ListMessages(ctx context.Context, conversationID string) ([]Me } const listMessagesByType = `-- name: ListMessagesByType :many -SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data FROM messages +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages WHERE conversation_id = ? AND type = ? ORDER BY sequence_id ASC ` @@ -229,6 +235,47 @@ func (q *Queries) ListMessagesByType(ctx context.Context, arg ListMessagesByType &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listMessagesForContext = `-- name: ListMessagesForContext :many +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages +WHERE conversation_id = ? AND excluded_from_context = FALSE +ORDER BY sequence_id ASC +` + +func (q *Queries) ListMessagesForContext(ctx context.Context, conversationID string) ([]Message, error) { + rows, err := q.db.QueryContext(ctx, listMessagesForContext, conversationID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Message{} + for rows.Next() { + var i Message + if err := rows.Scan( + &i.MessageID, + &i.ConversationID, + &i.SequenceID, + &i.Type, + &i.LlmData, + &i.UserData, + &i.UsageData, + &i.CreatedAt, + &i.DisplayData, + &i.ExcludedFromContext, ); err != nil { return nil, err } @@ -244,7 +291,7 @@ func (q *Queries) ListMessagesByType(ctx context.Context, arg ListMessagesByType } const listMessagesPaginated = `-- name: ListMessagesPaginated :many -SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data FROM messages +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages WHERE conversation_id = ? ORDER BY sequence_id ASC LIMIT ? OFFSET ? @@ -275,6 +322,7 @@ func (q *Queries) ListMessagesPaginated(ctx context.Context, arg ListMessagesPag &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, ); err != nil { return nil, err } @@ -290,7 +338,7 @@ func (q *Queries) ListMessagesPaginated(ctx context.Context, arg ListMessagesPag } const listMessagesSince = `-- name: ListMessagesSince :many -SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data FROM messages +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages WHERE conversation_id = ? AND sequence_id > ? ORDER BY sequence_id ASC ` @@ -319,6 +367,7 @@ func (q *Queries) ListMessagesSince(ctx context.Context, arg ListMessagesSincePa &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, ); err != nil { return nil, err } diff --git a/db/generated/models.go b/db/generated/models.go index 5d66089b07b8fe84874c7ecbc9ad6c879f56cbca..b6b5094fbac6188c7eb32be3d785b2cdf26866ed 100644 --- a/db/generated/models.go +++ b/db/generated/models.go @@ -37,15 +37,16 @@ type LlmRequest struct { } type Message struct { - MessageID string `json:"message_id"` - ConversationID string `json:"conversation_id"` - SequenceID int64 `json:"sequence_id"` - Type string `json:"type"` - LlmData *string `json:"llm_data"` - UserData *string `json:"user_data"` - UsageData *string `json:"usage_data"` - CreatedAt time.Time `json:"created_at"` - DisplayData *string `json:"display_data"` + MessageID string `json:"message_id"` + ConversationID string `json:"conversation_id"` + SequenceID int64 `json:"sequence_id"` + Type string `json:"type"` + LlmData *string `json:"llm_data"` + UserData *string `json:"user_data"` + UsageData *string `json:"usage_data"` + CreatedAt time.Time `json:"created_at"` + DisplayData *string `json:"display_data"` + ExcludedFromContext bool `json:"excluded_from_context"` } type Migration struct { diff --git a/db/query/messages.sql b/db/query/messages.sql index d33eb03f306dcf3e34eb7029b300ce83a896cb18..a481fc373adb10d4f394d0d24bfcb9c54fb71e54 100644 --- a/db/query/messages.sql +++ b/db/query/messages.sql @@ -1,6 +1,6 @@ -- name: CreateMessage :one -INSERT INTO messages (message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, display_data) -VALUES (?, ?, ?, ?, ?, ?, ?, ?) +INSERT INTO messages (message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, display_data, excluded_from_context) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING *; -- name: GetNextSequenceID :one @@ -17,6 +17,11 @@ SELECT * FROM messages WHERE conversation_id = ? ORDER BY sequence_id ASC; +-- name: ListMessagesForContext :many +SELECT * FROM messages +WHERE conversation_id = ? AND excluded_from_context = FALSE +ORDER BY sequence_id ASC; + -- name: ListMessagesPaginated :many SELECT * FROM messages WHERE conversation_id = ? diff --git a/db/schema/014-add-excluded-from-context.sql b/db/schema/014-add-excluded-from-context.sql new file mode 100644 index 0000000000000000000000000000000000000000..afff05c1edeedc92311c056984f5bea504c3c95b --- /dev/null +++ b/db/schema/014-add-excluded-from-context.sql @@ -0,0 +1,7 @@ +-- Add excluded_from_context column to messages table. +-- Messages with this flag set are stored for billing/cost tracking purposes +-- but are NOT included when building the LLM request context. +-- This is used for truncated responses that we want to keep for cost tracking +-- but that would confuse the LLM if sent back (e.g., partial tool calls). + +ALTER TABLE messages ADD COLUMN excluded_from_context BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/llm/llm.go b/llm/llm.go index d6414a7405a4dfac81677db7a0ee5fd2872cedd8..3e2c1c21e421acc5eb423fcd276c71abf1922456 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -60,6 +60,15 @@ func EmptySchema() json.RawMessage { return MustSchema(`{"type": "object", "properties": {}}`) } +// ErrorType identifies system-generated error messages (not LLM content). +type ErrorType string + +const ( + ErrorTypeNone ErrorType = "" // Not an error + ErrorTypeTruncation ErrorType = "truncation" // Response truncated due to max tokens + ErrorTypeLLMRequest ErrorType = "llm_request" // LLM request failed +) + type Request struct { Messages []Message ToolChoice *ToolChoice @@ -73,6 +82,14 @@ type Message struct { Content []Content `json:"Content"` ToolUse *ToolUse `json:"ToolUse,omitempty"` // use to control whether/which tool to use EndOfTurn bool `json:"EndOfTurn"` // true if this message completes the agent's turn (no tool calls to make) + + // ExcludedFromContext indicates this message should be stored but not sent back to the LLM. + // Used for truncated responses we want to keep for cost tracking but that would confuse the LLM. + ExcludedFromContext bool `json:"ExcludedFromContext,omitempty"` + + // ErrorType indicates this is a system-generated error message (not LLM content). + // Empty string means not an error. Values: "truncation", "llm_request". + ErrorType ErrorType `json:"ErrorType,omitempty"` } // ToolUse represents a tool use in the message content. diff --git a/loop/loop.go b/loop/loop.go index 6d46338d55de9c0540d209370895f3eb985de418..140b484b55b8b1f790b096c87445ba143b142f2d 100644 --- a/loop/loop.go +++ b/loop/loop.go @@ -255,6 +255,7 @@ func (l *Loop) processLLMRequest(ctx context.Context) error { }, }, EndOfTurn: true, + ErrorType: llm.ErrorTypeLLMRequest, } if recordErr := l.recordMessage(ctx, errorMessage, llm.Usage{}); recordErr != nil { l.logger.Error("failed to record error message", "error", recordErr) @@ -269,6 +270,13 @@ func (l *Loop) processLLMRequest(ctx context.Context) error { l.totalUsage.Add(resp.Usage) l.mu.Unlock() + // Handle max tokens truncation BEFORE adding to history - truncated responses + // should not be added to history normally (they get special handling) + if resp.StopReason == llm.StopReasonMaxTokens { + l.logger.Warn("LLM response truncated due to max tokens") + return l.handleMaxTokensTruncation(ctx, resp) + } + // Convert response to message and add to history assistantMessage := resp.ToMessage() l.mu.Lock() @@ -290,12 +298,6 @@ func (l *Loop) processLLMRequest(ctx context.Context) error { return l.handleToolCalls(ctx, resp.Content) } - // Handle max tokens truncation - record error message for the user - if resp.StopReason == llm.StopReasonMaxTokens { - l.logger.Warn("LLM response truncated due to max tokens") - return l.handleMaxTokensTruncation(ctx) - } - // End of turn - check for git state changes l.checkGitStateChange(ctx) @@ -340,11 +342,26 @@ func (l *Loop) checkGitStateChange(ctx context.Context) { } // handleMaxTokensTruncation handles the case where the LLM response was truncated -// due to hitting the maximum output token limit. It records an error message -// informing the user and instructing the LLM to use smaller outputs. -func (l *Loop) handleMaxTokensTruncation(ctx context.Context) error { +// due to hitting the maximum output token limit. It records the truncated message +// for cost tracking (excluded from context) and an error message for the user. +func (l *Loop) handleMaxTokensTruncation(ctx context.Context, resp *llm.Response) error { + // Record the truncated message for cost tracking, but mark it as excluded from context. + // This preserves billing information without confusing the LLM on future turns. + truncatedMessage := resp.ToMessage() + truncatedMessage.ExcludedFromContext = true + + // Record the truncated message with usage metadata + usageWithMeta := resp.Usage + usageWithMeta.Model = resp.Model + usageWithMeta.StartTime = resp.StartTime + usageWithMeta.EndTime = resp.EndTime + if err := l.recordMessage(ctx, truncatedMessage, usageWithMeta); err != nil { + l.logger.Error("failed to record truncated message", "error", err) + } + + // Record a truncation error message with EndOfTurn=true to properly signal end of turn. errorMessage := llm.Message{ - Role: llm.MessageRoleUser, + Role: llm.MessageRoleAssistant, Content: []llm.Content{ { Type: llm.ContentTypeText, @@ -354,13 +371,15 @@ func (l *Loop) handleMaxTokensTruncation(ctx context.Context) error { "The user can ask you to continue if needed.]", }, }, + EndOfTurn: true, + ErrorType: llm.ErrorTypeTruncation, } l.mu.Lock() l.history = append(l.history, errorMessage) l.mu.Unlock() - // Record the error message + // Record the truncation error message if err := l.recordMessage(ctx, errorMessage, llm.Usage{}); err != nil { l.logger.Error("failed to record truncation error message", "error", err) } diff --git a/loop/loop_test.go b/loop/loop_test.go index 88d015447d6ac5c14fed23cc05efec4a9430c597..692783bea4722f3c9e48ae5736c2db6312386b05 100644 --- a/loop/loop_test.go +++ b/loop/loop_test.go @@ -1797,6 +1797,97 @@ func TestHandleToolCallsWithErrorTool(t *testing.T) { } } +func TestMaxTokensTruncation(t *testing.T) { + var mu sync.Mutex + var recordedMessages []llm.Message + recordFunc := func(ctx context.Context, message llm.Message, usage llm.Usage) error { + mu.Lock() + recordedMessages = append(recordedMessages, message) + mu.Unlock() + return nil + } + + service := NewPredictableService() + loop := NewLoop(Config{ + LLM: service, + History: []llm.Message{}, + Tools: []*llm.Tool{}, + RecordMessage: recordFunc, + }) + + // Queue a user message that triggers max tokens truncation + userMessage := llm.Message{ + Role: llm.MessageRoleUser, + Content: []llm.Content{{Type: llm.ContentTypeText, Text: "maxTokens"}}, + } + loop.QueueUserMessage(userMessage) + + // Run the loop - it should stop after handling truncation + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + err := loop.Go(ctx) + if err != context.DeadlineExceeded { + t.Errorf("expected context deadline exceeded, got %v", err) + } + + // Check recorded messages + mu.Lock() + numMessages := len(recordedMessages) + messages := make([]llm.Message, len(recordedMessages)) + copy(messages, recordedMessages) + mu.Unlock() + + // We should see two messages: + // 1. The truncated message (with ExcludedFromContext=true) for cost tracking + // 2. The truncation error message (with ErrorType=truncation) + if numMessages != 2 { + t.Errorf("Expected 2 recorded messages (truncated + error), got %d", numMessages) + for i, msg := range messages { + t.Logf("Message %d: Role=%v, EndOfTurn=%v, ExcludedFromContext=%v, ErrorType=%v", + i, msg.Role, msg.EndOfTurn, msg.ExcludedFromContext, msg.ErrorType) + } + return + } + + // First message: truncated response (for cost tracking, excluded from context) + truncatedMsg := messages[0] + if truncatedMsg.Role != llm.MessageRoleAssistant { + t.Errorf("Truncated message should be assistant, got %v", truncatedMsg.Role) + } + if !truncatedMsg.ExcludedFromContext { + t.Error("Truncated message should have ExcludedFromContext=true") + } + + // Second message: truncation error + errorMsg := messages[1] + if errorMsg.Role != llm.MessageRoleAssistant { + t.Errorf("Error message should be assistant, got %v", errorMsg.Role) + } + if !errorMsg.EndOfTurn { + t.Error("Error message should have EndOfTurn=true") + } + if errorMsg.ErrorType != llm.ErrorTypeTruncation { + t.Errorf("Error message should have ErrorType=truncation, got %v", errorMsg.ErrorType) + } + if errorMsg.ExcludedFromContext { + t.Error("Error message should not be excluded from context") + } + if !strings.Contains(errorMsg.Content[0].Text, "SYSTEM ERROR") { + t.Errorf("Error message should contain SYSTEM ERROR, got: %s", errorMsg.Content[0].Text) + } + + // Verify history contains user message + error message, but NOT the truncated response + loop.mu.Lock() + history := loop.history + loop.mu.Unlock() + + // History should have: user message + error message (the truncated response is NOT added to history) + if len(history) != 2 { + t.Errorf("History should have 2 messages (user + error), got %d", len(history)) + } +} + //func TestInsertMissingToolResultsEdgeCases(t *testing.T) { // loop := NewLoop(Config{ // LLM: NewPredictableService(), diff --git a/server/convo.go b/server/convo.go index 51d4f1e9c24530038286cba6ce4df5e0fe73085d..ec17267956d3fe8bbe5dd9e96bfc4ff2dff72160 100644 --- a/server/convo.go +++ b/server/convo.go @@ -127,7 +127,8 @@ func (cm *ConversationManager) Hydrate(ctx context.Context) error { var messages []generated.Message err = cm.db.Queries(ctx, func(q *generated.Queries) error { var err error - messages, err = q.ListMessages(ctx, cm.conversationID) + // Use ListMessagesForContext to exclude messages marked as excluded_from_context + messages, err = q.ListMessagesForContext(ctx, cm.conversationID) return err }) if err != nil { @@ -325,6 +326,12 @@ func (cm *ConversationManager) partitionMessages(messages []generated.Message) ( continue } + // Skip error messages - they are system-generated for user visibility, + // but should not be sent to the LLM as they are not part of the conversation + if msg.Type == string(db.MessageTypeError) { + continue + } + llmMsg, err := convertToLLMMessage(msg) if err != nil { cm.logger.Warn("Failed to convert message to LLM format", "messageID", msg.MessageID, "error", err) diff --git a/server/server.go b/server/server.go index ce79d9b233cb7a5814e7915646c3a9d955e3df27..01ae012cb30fd13da8f525fca2181f9a32b050f7 100644 --- a/server/server.go +++ b/server/server.go @@ -165,14 +165,14 @@ func calculateContextWindowSize(messages []APIMessage) uint64 { return 0 } -// isAgentEndOfTurn checks if a message is an agent message with end_of_turn=true. +// isAgentEndOfTurn checks if a message is an agent or error message with end_of_turn=true. // This indicates the agent loop has finished processing. func isAgentEndOfTurn(msg *generated.Message) bool { if msg == nil { return false } - // Only agent messages can have end_of_turn - if msg.Type != string(db.MessageTypeAgent) { + // Agent and error messages can have end_of_turn + if msg.Type != string(db.MessageTypeAgent) && msg.Type != string(db.MessageTypeError) { return false } if msg.LlmData == nil { @@ -599,12 +599,13 @@ func (s *Server) recordMessage(ctx context.Context, conversationID string, messa // Create message createdMsg, err := s.db.CreateMessage(ctx, db.CreateMessageParams{ - ConversationID: conversationID, - Type: messageType, - LLMData: message, - UserData: nil, - UsageData: usage, - DisplayData: displayDataToStore, + ConversationID: conversationID, + Type: messageType, + LLMData: message, + UserData: nil, + UsageData: usage, + DisplayData: displayDataToStore, + ExcludedFromContext: message.ExcludedFromContext, }) if err != nil { return fmt.Errorf("failed to create message: %w", err) @@ -635,16 +636,15 @@ func (s *Server) recordMessage(ctx context.Context, conversationID string, messa // getMessageType determines the message type from an LLM message func (s *Server) getMessageType(message llm.Message) (db.MessageType, error) { + // System-generated errors are stored as error type + if message.ErrorType != llm.ErrorTypeNone { + return db.MessageTypeError, nil + } + switch message.Role { case llm.MessageRoleUser: return db.MessageTypeUser, nil case llm.MessageRoleAssistant: - // Check if this is an error message by looking at content - for _, content := range message.Content { - if content.Type == llm.ContentTypeText && strings.HasPrefix(content.Text, "LLM request failed:") { - return db.MessageTypeError, nil - } - } return db.MessageTypeAgent, nil default: // For tool messages, check if it's a tool call or tool result