fix cancell

Kujtim Hoxha created

Change summary

internal/llm/provider/anthropic.go |  2 +
internal/llm/provider/gemini.go    |  2 +
internal/llm/provider/openai.go    |  3 +
internal/llm/provider/provider.go  | 36 ++++++++++++++++++++++++++++++++
4 files changed, 42 insertions(+), 1 deletion(-)

Detailed changes

internal/llm/provider/anthropic.go 🔗

@@ -92,6 +92,7 @@ func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 }
 
 func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+	messages = cleanupMessages(messages)
 	anthropicMessages := a.convertToAnthropicMessages(messages)
 	anthropicTools := a.convertToAnthropicTools(tools)
 
@@ -135,6 +136,7 @@ func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message
 }
 
 func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
+	messages = cleanupMessages(messages)
 	anthropicMessages := a.convertToAnthropicMessages(messages)
 	anthropicTools := a.convertToAnthropicTools(tools)
 

internal/llm/provider/gemini.go 🔗

@@ -154,6 +154,7 @@ func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse)
 }
 
 func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+	messages = cleanupMessages(messages)
 	model := p.client.GenerativeModel(p.model.APIModel)
 	model.SetMaxOutputTokens(p.maxTokens)
 
@@ -206,6 +207,7 @@ func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Me
 }
 
 func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
+	messages = cleanupMessages(messages)
 	model := p.client.GenerativeModel(p.model.APIModel)
 	model.SetMaxOutputTokens(p.maxTokens)
 

internal/llm/provider/openai.go 🔗

@@ -163,6 +163,7 @@ func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUs
 }
 
 func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+	messages = cleanupMessages(messages)
 	chatMessages := p.convertToOpenAIMessages(messages)
 	openaiTools := p.convertToOpenAITools(tools)
 
@@ -206,6 +207,7 @@ func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Me
 }
 
 func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
+	messages = cleanupMessages(messages)
 	chatMessages := p.convertToOpenAIMessages(messages)
 	openaiTools := p.convertToOpenAITools(tools)
 
@@ -276,4 +278,3 @@ func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.
 
 	return eventChan, nil
 }
-

internal/llm/provider/provider.go 🔗

@@ -52,3 +52,39 @@ type Provider interface {
 
 	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error)
 }
+
+func cleanupMessages(messages []message.Message) []message.Message {
+	// First pass: filter out canceled messages
+	var cleanedMessages []message.Message
+	for _, msg := range messages {
+		if msg.FinishReason() != "canceled" {
+			cleanedMessages = append(cleanedMessages, msg)
+		}
+	}
+
+	// Second pass: filter out tool messages without a corresponding tool call
+	var result []message.Message
+	toolMessageIDs := make(map[string]bool)
+
+	for _, msg := range cleanedMessages {
+		if msg.Role == message.Assistant {
+			for _, toolCall := range msg.ToolCalls() {
+				toolMessageIDs[toolCall.ID] = true // Mark as referenced
+			}
+		}
+	}
+
+	// Keep only messages that aren't unreferenced tool messages
+	for _, msg := range cleanedMessages {
+		if msg.Role == message.Tool {
+			for _, toolCall := range msg.ToolResults() {
+				if referenced, exists := toolMessageIDs[toolCall.ToolCallID]; exists && referenced {
+					result = append(result, msg)
+				}
+			}
+		} else {
+			result = append(result, msg)
+		}
+	}
+	return result
+}