From df2b63ffd46c7acea6b29963d2872846516b49b3 Mon Sep 17 00:00:00 2001 From: Philip Zeyliger Date: Fri, 23 Jan 2026 21:06:59 -0800 Subject: [PATCH] fix: handle truncated LLM responses correctly This is iteration on https://github.com/boldsoftware/shelley/issues/15. How to handle truncation is tricky; we observed it was still keeping the agent as "running" and was confusing with the error. Let's see how this goes. When an LLM response is truncated due to max_tokens: 1. The truncated message is preserved in the database with excluded_from_context=true for cost/billing tracking, but not sent back to the LLM (partial tool calls confuse it) 2. A system error message is recorded with ErrorType set, stored as message type 'error'. This properly signals end of turn so the UI updates the agent working state. 3. Error messages are identified by the ErrorType field on llm.Message, not by fragile text prefix matching on message content. Changes: - Add excluded_from_context column to messages table (migration 014) - Add ListMessagesForContext query to filter excluded messages - Add ErrorType field to llm.Message (truncation, llm_request) - Handle max tokens BEFORE adding response to history (prevents double recording) - Record truncated message (excluded) + error message on truncation - Check ErrorType in getMessageType() to return db.MessageTypeError - Skip error messages when building LLM context - Update isAgentEndOfTurn to handle error messages - Add test for max tokens truncation handling Co-authored-by: Shelley --- db/db.go | 42 ++++++---- db/generated/messages.sql.go | 83 +++++++++++++++---- db/generated/models.go | 19 +++-- db/query/messages.sql | 9 +- db/schema/014-add-excluded-from-context.sql | 7 ++ llm/llm.go | 17 ++++ loop/loop.go | 41 +++++++--- loop/loop_test.go | 91 +++++++++++++++++++++ server/convo.go | 9 +- server/server.go | 30 +++---- 10 files changed, 279 insertions(+), 69 deletions(-) create mode 100644 db/schema/014-add-excluded-from-context.sql diff --git a/db/db.go b/db/db.go index c922094d2dae3563c97ef358a6e018182c82fd43..3b4543d62ae4639d5912eca96fcf213ea57937b1 100644 --- a/db/db.go +++ b/db/db.go @@ -391,12 +391,13 @@ const ( // CreateMessageParams contains parameters for creating a message type CreateMessageParams struct { - ConversationID string - Type MessageType - LLMData interface{} // Will be JSON marshalled - UserData interface{} // Will be JSON marshalled - UsageData interface{} // Will be JSON marshalled - DisplayData interface{} // Will be JSON marshalled, tool-specific display content + ConversationID string + Type MessageType + LLMData interface{} // Will be JSON marshalled + UserData interface{} // Will be JSON marshalled + UsageData interface{} // Will be JSON marshalled + DisplayData interface{} // Will be JSON marshalled, tool-specific display content + ExcludedFromContext bool // If true, message is stored but not sent to LLM } // CreateMessage creates a new message @@ -453,14 +454,15 @@ func (db *DB) CreateMessage(ctx context.Context, params CreateMessageParams) (*g } message, err = q.CreateMessage(ctx, generated.CreateMessageParams{ - MessageID: messageID, - ConversationID: params.ConversationID, - SequenceID: sequenceID, - Type: string(params.Type), - LlmData: llmDataJSON, - UserData: userDataJSON, - UsageData: usageDataJSON, - DisplayData: displayDataJSON, + MessageID: messageID, + ConversationID: params.ConversationID, + SequenceID: sequenceID, + Type: string(params.Type), + LlmData: llmDataJSON, + UserData: userDataJSON, + UsageData: usageDataJSON, + DisplayData: displayDataJSON, + ExcludedFromContext: params.ExcludedFromContext, }) return err }) @@ -498,6 +500,18 @@ func (db *DB) ListMessagesByConversationPaginated(ctx context.Context, conversat return messages, err } +// ListMessagesForContext retrieves messages that should be sent to the LLM (excludes excluded_from_context=true) +func (db *DB) ListMessagesForContext(ctx context.Context, conversationID string) ([]generated.Message, error) { + var messages []generated.Message + err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error { + q := generated.New(rx.Conn()) + var err error + messages, err = q.ListMessagesForContext(ctx, conversationID) + return err + }) + return messages, err +} + // ListMessagesByType retrieves messages of a specific type in a conversation func (db *DB) ListMessagesByType(ctx context.Context, conversationID string, messageType MessageType) ([]generated.Message, error) { var messages []generated.Message diff --git a/db/generated/messages.sql.go b/db/generated/messages.sql.go index f5f5f13b00933e55b18af3da2b668209f34f4ed6..75935e3470286193c94ae447bc03fafb53cf0b49 100644 --- a/db/generated/messages.sql.go +++ b/db/generated/messages.sql.go @@ -39,20 +39,21 @@ func (q *Queries) CountMessagesInConversation(ctx context.Context, conversationI } const createMessage = `-- name: CreateMessage :one -INSERT INTO messages (message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, display_data) -VALUES (?, ?, ?, ?, ?, ?, ?, ?) -RETURNING message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data +INSERT INTO messages (message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, display_data, excluded_from_context) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) +RETURNING message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context ` type CreateMessageParams struct { - MessageID string `json:"message_id"` - ConversationID string `json:"conversation_id"` - SequenceID int64 `json:"sequence_id"` - Type string `json:"type"` - LlmData *string `json:"llm_data"` - UserData *string `json:"user_data"` - UsageData *string `json:"usage_data"` - DisplayData *string `json:"display_data"` + MessageID string `json:"message_id"` + ConversationID string `json:"conversation_id"` + SequenceID int64 `json:"sequence_id"` + Type string `json:"type"` + LlmData *string `json:"llm_data"` + UserData *string `json:"user_data"` + UsageData *string `json:"usage_data"` + DisplayData *string `json:"display_data"` + ExcludedFromContext bool `json:"excluded_from_context"` } func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) { @@ -65,6 +66,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M arg.UserData, arg.UsageData, arg.DisplayData, + arg.ExcludedFromContext, ) var i Message err := row.Scan( @@ -77,6 +79,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, ) return i, err } @@ -102,7 +105,7 @@ func (q *Queries) DeleteMessage(ctx context.Context, messageID string) error { } const getLatestMessage = `-- name: GetLatestMessage :one -SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data FROM messages +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages WHERE conversation_id = ? ORDER BY sequence_id DESC LIMIT 1 @@ -121,12 +124,13 @@ func (q *Queries) GetLatestMessage(ctx context.Context, conversationID string) ( &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, ) return i, err } const getMessage = `-- name: GetMessage :one -SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data FROM messages +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages WHERE message_id = ? ` @@ -143,6 +147,7 @@ func (q *Queries) GetMessage(ctx context.Context, messageID string) (Message, er &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, ) return i, err } @@ -161,7 +166,7 @@ func (q *Queries) GetNextSequenceID(ctx context.Context, conversationID string) } const listMessages = `-- name: ListMessages :many -SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data FROM messages +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages WHERE conversation_id = ? ORDER BY sequence_id ASC ` @@ -185,6 +190,7 @@ func (q *Queries) ListMessages(ctx context.Context, conversationID string) ([]Me &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, ); err != nil { return nil, err } @@ -200,7 +206,7 @@ func (q *Queries) ListMessages(ctx context.Context, conversationID string) ([]Me } const listMessagesByType = `-- name: ListMessagesByType :many -SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data FROM messages +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages WHERE conversation_id = ? AND type = ? ORDER BY sequence_id ASC ` @@ -229,6 +235,47 @@ func (q *Queries) ListMessagesByType(ctx context.Context, arg ListMessagesByType &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listMessagesForContext = `-- name: ListMessagesForContext :many +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages +WHERE conversation_id = ? AND excluded_from_context = FALSE +ORDER BY sequence_id ASC +` + +func (q *Queries) ListMessagesForContext(ctx context.Context, conversationID string) ([]Message, error) { + rows, err := q.db.QueryContext(ctx, listMessagesForContext, conversationID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Message{} + for rows.Next() { + var i Message + if err := rows.Scan( + &i.MessageID, + &i.ConversationID, + &i.SequenceID, + &i.Type, + &i.LlmData, + &i.UserData, + &i.UsageData, + &i.CreatedAt, + &i.DisplayData, + &i.ExcludedFromContext, ); err != nil { return nil, err } @@ -244,7 +291,7 @@ func (q *Queries) ListMessagesByType(ctx context.Context, arg ListMessagesByType } const listMessagesPaginated = `-- name: ListMessagesPaginated :many -SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data FROM messages +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages WHERE conversation_id = ? ORDER BY sequence_id ASC LIMIT ? OFFSET ? @@ -275,6 +322,7 @@ func (q *Queries) ListMessagesPaginated(ctx context.Context, arg ListMessagesPag &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, ); err != nil { return nil, err } @@ -290,7 +338,7 @@ func (q *Queries) ListMessagesPaginated(ctx context.Context, arg ListMessagesPag } const listMessagesSince = `-- name: ListMessagesSince :many -SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data FROM messages +SELECT message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, created_at, display_data, excluded_from_context FROM messages WHERE conversation_id = ? AND sequence_id > ? ORDER BY sequence_id ASC ` @@ -319,6 +367,7 @@ func (q *Queries) ListMessagesSince(ctx context.Context, arg ListMessagesSincePa &i.UsageData, &i.CreatedAt, &i.DisplayData, + &i.ExcludedFromContext, ); err != nil { return nil, err } diff --git a/db/generated/models.go b/db/generated/models.go index 5d66089b07b8fe84874c7ecbc9ad6c879f56cbca..b6b5094fbac6188c7eb32be3d785b2cdf26866ed 100644 --- a/db/generated/models.go +++ b/db/generated/models.go @@ -37,15 +37,16 @@ type LlmRequest struct { } type Message struct { - MessageID string `json:"message_id"` - ConversationID string `json:"conversation_id"` - SequenceID int64 `json:"sequence_id"` - Type string `json:"type"` - LlmData *string `json:"llm_data"` - UserData *string `json:"user_data"` - UsageData *string `json:"usage_data"` - CreatedAt time.Time `json:"created_at"` - DisplayData *string `json:"display_data"` + MessageID string `json:"message_id"` + ConversationID string `json:"conversation_id"` + SequenceID int64 `json:"sequence_id"` + Type string `json:"type"` + LlmData *string `json:"llm_data"` + UserData *string `json:"user_data"` + UsageData *string `json:"usage_data"` + CreatedAt time.Time `json:"created_at"` + DisplayData *string `json:"display_data"` + ExcludedFromContext bool `json:"excluded_from_context"` } type Migration struct { diff --git a/db/query/messages.sql b/db/query/messages.sql index d33eb03f306dcf3e34eb7029b300ce83a896cb18..a481fc373adb10d4f394d0d24bfcb9c54fb71e54 100644 --- a/db/query/messages.sql +++ b/db/query/messages.sql @@ -1,6 +1,6 @@ -- name: CreateMessage :one -INSERT INTO messages (message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, display_data) -VALUES (?, ?, ?, ?, ?, ?, ?, ?) +INSERT INTO messages (message_id, conversation_id, sequence_id, type, llm_data, user_data, usage_data, display_data, excluded_from_context) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING *; -- name: GetNextSequenceID :one @@ -17,6 +17,11 @@ SELECT * FROM messages WHERE conversation_id = ? ORDER BY sequence_id ASC; +-- name: ListMessagesForContext :many +SELECT * FROM messages +WHERE conversation_id = ? AND excluded_from_context = FALSE +ORDER BY sequence_id ASC; + -- name: ListMessagesPaginated :many SELECT * FROM messages WHERE conversation_id = ? diff --git a/db/schema/014-add-excluded-from-context.sql b/db/schema/014-add-excluded-from-context.sql new file mode 100644 index 0000000000000000000000000000000000000000..afff05c1edeedc92311c056984f5bea504c3c95b --- /dev/null +++ b/db/schema/014-add-excluded-from-context.sql @@ -0,0 +1,7 @@ +-- Add excluded_from_context column to messages table. +-- Messages with this flag set are stored for billing/cost tracking purposes +-- but are NOT included when building the LLM request context. +-- This is used for truncated responses that we want to keep for cost tracking +-- but that would confuse the LLM if sent back (e.g., partial tool calls). + +ALTER TABLE messages ADD COLUMN excluded_from_context BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/llm/llm.go b/llm/llm.go index d6414a7405a4dfac81677db7a0ee5fd2872cedd8..3e2c1c21e421acc5eb423fcd276c71abf1922456 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -60,6 +60,15 @@ func EmptySchema() json.RawMessage { return MustSchema(`{"type": "object", "properties": {}}`) } +// ErrorType identifies system-generated error messages (not LLM content). +type ErrorType string + +const ( + ErrorTypeNone ErrorType = "" // Not an error + ErrorTypeTruncation ErrorType = "truncation" // Response truncated due to max tokens + ErrorTypeLLMRequest ErrorType = "llm_request" // LLM request failed +) + type Request struct { Messages []Message ToolChoice *ToolChoice @@ -73,6 +82,14 @@ type Message struct { Content []Content `json:"Content"` ToolUse *ToolUse `json:"ToolUse,omitempty"` // use to control whether/which tool to use EndOfTurn bool `json:"EndOfTurn"` // true if this message completes the agent's turn (no tool calls to make) + + // ExcludedFromContext indicates this message should be stored but not sent back to the LLM. + // Used for truncated responses we want to keep for cost tracking but that would confuse the LLM. + ExcludedFromContext bool `json:"ExcludedFromContext,omitempty"` + + // ErrorType indicates this is a system-generated error message (not LLM content). + // Empty string means not an error. Values: "truncation", "llm_request". + ErrorType ErrorType `json:"ErrorType,omitempty"` } // ToolUse represents a tool use in the message content. diff --git a/loop/loop.go b/loop/loop.go index 6d46338d55de9c0540d209370895f3eb985de418..140b484b55b8b1f790b096c87445ba143b142f2d 100644 --- a/loop/loop.go +++ b/loop/loop.go @@ -255,6 +255,7 @@ func (l *Loop) processLLMRequest(ctx context.Context) error { }, }, EndOfTurn: true, + ErrorType: llm.ErrorTypeLLMRequest, } if recordErr := l.recordMessage(ctx, errorMessage, llm.Usage{}); recordErr != nil { l.logger.Error("failed to record error message", "error", recordErr) @@ -269,6 +270,13 @@ func (l *Loop) processLLMRequest(ctx context.Context) error { l.totalUsage.Add(resp.Usage) l.mu.Unlock() + // Handle max tokens truncation BEFORE adding to history - truncated responses + // should not be added to history normally (they get special handling) + if resp.StopReason == llm.StopReasonMaxTokens { + l.logger.Warn("LLM response truncated due to max tokens") + return l.handleMaxTokensTruncation(ctx, resp) + } + // Convert response to message and add to history assistantMessage := resp.ToMessage() l.mu.Lock() @@ -290,12 +298,6 @@ func (l *Loop) processLLMRequest(ctx context.Context) error { return l.handleToolCalls(ctx, resp.Content) } - // Handle max tokens truncation - record error message for the user - if resp.StopReason == llm.StopReasonMaxTokens { - l.logger.Warn("LLM response truncated due to max tokens") - return l.handleMaxTokensTruncation(ctx) - } - // End of turn - check for git state changes l.checkGitStateChange(ctx) @@ -340,11 +342,26 @@ func (l *Loop) checkGitStateChange(ctx context.Context) { } // handleMaxTokensTruncation handles the case where the LLM response was truncated -// due to hitting the maximum output token limit. It records an error message -// informing the user and instructing the LLM to use smaller outputs. -func (l *Loop) handleMaxTokensTruncation(ctx context.Context) error { +// due to hitting the maximum output token limit. It records the truncated message +// for cost tracking (excluded from context) and an error message for the user. +func (l *Loop) handleMaxTokensTruncation(ctx context.Context, resp *llm.Response) error { + // Record the truncated message for cost tracking, but mark it as excluded from context. + // This preserves billing information without confusing the LLM on future turns. + truncatedMessage := resp.ToMessage() + truncatedMessage.ExcludedFromContext = true + + // Record the truncated message with usage metadata + usageWithMeta := resp.Usage + usageWithMeta.Model = resp.Model + usageWithMeta.StartTime = resp.StartTime + usageWithMeta.EndTime = resp.EndTime + if err := l.recordMessage(ctx, truncatedMessage, usageWithMeta); err != nil { + l.logger.Error("failed to record truncated message", "error", err) + } + + // Record a truncation error message with EndOfTurn=true to properly signal end of turn. errorMessage := llm.Message{ - Role: llm.MessageRoleUser, + Role: llm.MessageRoleAssistant, Content: []llm.Content{ { Type: llm.ContentTypeText, @@ -354,13 +371,15 @@ func (l *Loop) handleMaxTokensTruncation(ctx context.Context) error { "The user can ask you to continue if needed.]", }, }, + EndOfTurn: true, + ErrorType: llm.ErrorTypeTruncation, } l.mu.Lock() l.history = append(l.history, errorMessage) l.mu.Unlock() - // Record the error message + // Record the truncation error message if err := l.recordMessage(ctx, errorMessage, llm.Usage{}); err != nil { l.logger.Error("failed to record truncation error message", "error", err) } diff --git a/loop/loop_test.go b/loop/loop_test.go index 88d015447d6ac5c14fed23cc05efec4a9430c597..692783bea4722f3c9e48ae5736c2db6312386b05 100644 --- a/loop/loop_test.go +++ b/loop/loop_test.go @@ -1797,6 +1797,97 @@ func TestHandleToolCallsWithErrorTool(t *testing.T) { } } +func TestMaxTokensTruncation(t *testing.T) { + var mu sync.Mutex + var recordedMessages []llm.Message + recordFunc := func(ctx context.Context, message llm.Message, usage llm.Usage) error { + mu.Lock() + recordedMessages = append(recordedMessages, message) + mu.Unlock() + return nil + } + + service := NewPredictableService() + loop := NewLoop(Config{ + LLM: service, + History: []llm.Message{}, + Tools: []*llm.Tool{}, + RecordMessage: recordFunc, + }) + + // Queue a user message that triggers max tokens truncation + userMessage := llm.Message{ + Role: llm.MessageRoleUser, + Content: []llm.Content{{Type: llm.ContentTypeText, Text: "maxTokens"}}, + } + loop.QueueUserMessage(userMessage) + + // Run the loop - it should stop after handling truncation + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + err := loop.Go(ctx) + if err != context.DeadlineExceeded { + t.Errorf("expected context deadline exceeded, got %v", err) + } + + // Check recorded messages + mu.Lock() + numMessages := len(recordedMessages) + messages := make([]llm.Message, len(recordedMessages)) + copy(messages, recordedMessages) + mu.Unlock() + + // We should see two messages: + // 1. The truncated message (with ExcludedFromContext=true) for cost tracking + // 2. The truncation error message (with ErrorType=truncation) + if numMessages != 2 { + t.Errorf("Expected 2 recorded messages (truncated + error), got %d", numMessages) + for i, msg := range messages { + t.Logf("Message %d: Role=%v, EndOfTurn=%v, ExcludedFromContext=%v, ErrorType=%v", + i, msg.Role, msg.EndOfTurn, msg.ExcludedFromContext, msg.ErrorType) + } + return + } + + // First message: truncated response (for cost tracking, excluded from context) + truncatedMsg := messages[0] + if truncatedMsg.Role != llm.MessageRoleAssistant { + t.Errorf("Truncated message should be assistant, got %v", truncatedMsg.Role) + } + if !truncatedMsg.ExcludedFromContext { + t.Error("Truncated message should have ExcludedFromContext=true") + } + + // Second message: truncation error + errorMsg := messages[1] + if errorMsg.Role != llm.MessageRoleAssistant { + t.Errorf("Error message should be assistant, got %v", errorMsg.Role) + } + if !errorMsg.EndOfTurn { + t.Error("Error message should have EndOfTurn=true") + } + if errorMsg.ErrorType != llm.ErrorTypeTruncation { + t.Errorf("Error message should have ErrorType=truncation, got %v", errorMsg.ErrorType) + } + if errorMsg.ExcludedFromContext { + t.Error("Error message should not be excluded from context") + } + if !strings.Contains(errorMsg.Content[0].Text, "SYSTEM ERROR") { + t.Errorf("Error message should contain SYSTEM ERROR, got: %s", errorMsg.Content[0].Text) + } + + // Verify history contains user message + error message, but NOT the truncated response + loop.mu.Lock() + history := loop.history + loop.mu.Unlock() + + // History should have: user message + error message (the truncated response is NOT added to history) + if len(history) != 2 { + t.Errorf("History should have 2 messages (user + error), got %d", len(history)) + } +} + //func TestInsertMissingToolResultsEdgeCases(t *testing.T) { // loop := NewLoop(Config{ // LLM: NewPredictableService(), diff --git a/server/convo.go b/server/convo.go index 51d4f1e9c24530038286cba6ce4df5e0fe73085d..ec17267956d3fe8bbe5dd9e96bfc4ff2dff72160 100644 --- a/server/convo.go +++ b/server/convo.go @@ -127,7 +127,8 @@ func (cm *ConversationManager) Hydrate(ctx context.Context) error { var messages []generated.Message err = cm.db.Queries(ctx, func(q *generated.Queries) error { var err error - messages, err = q.ListMessages(ctx, cm.conversationID) + // Use ListMessagesForContext to exclude messages marked as excluded_from_context + messages, err = q.ListMessagesForContext(ctx, cm.conversationID) return err }) if err != nil { @@ -325,6 +326,12 @@ func (cm *ConversationManager) partitionMessages(messages []generated.Message) ( continue } + // Skip error messages - they are system-generated for user visibility, + // but should not be sent to the LLM as they are not part of the conversation + if msg.Type == string(db.MessageTypeError) { + continue + } + llmMsg, err := convertToLLMMessage(msg) if err != nil { cm.logger.Warn("Failed to convert message to LLM format", "messageID", msg.MessageID, "error", err) diff --git a/server/server.go b/server/server.go index ce79d9b233cb7a5814e7915646c3a9d955e3df27..01ae012cb30fd13da8f525fca2181f9a32b050f7 100644 --- a/server/server.go +++ b/server/server.go @@ -165,14 +165,14 @@ func calculateContextWindowSize(messages []APIMessage) uint64 { return 0 } -// isAgentEndOfTurn checks if a message is an agent message with end_of_turn=true. +// isAgentEndOfTurn checks if a message is an agent or error message with end_of_turn=true. // This indicates the agent loop has finished processing. func isAgentEndOfTurn(msg *generated.Message) bool { if msg == nil { return false } - // Only agent messages can have end_of_turn - if msg.Type != string(db.MessageTypeAgent) { + // Agent and error messages can have end_of_turn + if msg.Type != string(db.MessageTypeAgent) && msg.Type != string(db.MessageTypeError) { return false } if msg.LlmData == nil { @@ -599,12 +599,13 @@ func (s *Server) recordMessage(ctx context.Context, conversationID string, messa // Create message createdMsg, err := s.db.CreateMessage(ctx, db.CreateMessageParams{ - ConversationID: conversationID, - Type: messageType, - LLMData: message, - UserData: nil, - UsageData: usage, - DisplayData: displayDataToStore, + ConversationID: conversationID, + Type: messageType, + LLMData: message, + UserData: nil, + UsageData: usage, + DisplayData: displayDataToStore, + ExcludedFromContext: message.ExcludedFromContext, }) if err != nil { return fmt.Errorf("failed to create message: %w", err) @@ -635,16 +636,15 @@ func (s *Server) recordMessage(ctx context.Context, conversationID string, messa // getMessageType determines the message type from an LLM message func (s *Server) getMessageType(message llm.Message) (db.MessageType, error) { + // System-generated errors are stored as error type + if message.ErrorType != llm.ErrorTypeNone { + return db.MessageTypeError, nil + } + switch message.Role { case llm.MessageRoleUser: return db.MessageTypeUser, nil case llm.MessageRoleAssistant: - // Check if this is an error message by looking at content - for _, content := range message.Content { - if content.Type == llm.ContentTypeText && strings.HasPrefix(content.Text, "LLM request failed:") { - return db.MessageTypeError, nil - } - } return db.MessageTypeAgent, nil default: // For tool messages, check if it's a tool call or tool result