1package provider
2
3import (
4 "context"
5
6 "github.com/kujtimiihoxha/termai/internal/llm/tools"
7 "github.com/kujtimiihoxha/termai/internal/message"
8)
9
// EventType identifies the kind of a ProviderEvent delivered on the channel
// returned by Provider.StreamResponse.
type EventType string

// Streaming event kinds, grouped by purpose.
const (
	// Content stream events.
	EventContentStart  EventType = "content_start"
	EventContentDelta  EventType = "content_delta"
	EventThinkingDelta EventType = "thinking_delta"
	EventContentStop   EventType = "content_stop"

	// Outcome events.
	EventComplete EventType = "complete"
	EventError    EventType = "error"

	// Informational events for the user.
	EventWarning EventType = "warning"
	EventInfo    EventType = "info"
)
23
// TokenUsage records the token counts reported for a provider call.
type TokenUsage struct {
	// InputTokens is the number of input tokens.
	InputTokens int64
	// OutputTokens is the number of output tokens.
	OutputTokens int64
	// CacheCreationTokens is the number of tokens written to a prompt cache,
	// presumably only populated by providers that support caching.
	CacheCreationTokens int64
	// CacheReadTokens is the number of tokens served from a prompt cache,
	// presumably only populated by providers that support caching.
	CacheReadTokens int64
}
30
31type ProviderResponse struct {
32 Content string
33 ToolCalls []message.ToolCall
34 Usage TokenUsage
35 FinishReason string
36}
37
38type ProviderEvent struct {
39 Type EventType
40 Content string
41 Thinking string
42 ToolCall *message.ToolCall
43 Error error
44 Response *ProviderResponse
45
46 // Used for giving users info on e.x retry
47 Info string
48}
49
50type Provider interface {
51 SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
52
53 StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error)
54}
55
56func cleanupMessages(messages []message.Message) []message.Message {
57 // First pass: filter out canceled messages
58 var cleanedMessages []message.Message
59 for _, msg := range messages {
60 if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 {
61 // if there are toolCalls this means we want to return it to the LLM telling it that those tools have been
62 // cancelled
63 cleanedMessages = append(cleanedMessages, msg)
64 }
65 }
66
67 // Second pass: filter out tool messages without a corresponding tool call
68 var result []message.Message
69 toolMessageIDs := make(map[string]bool)
70
71 for _, msg := range cleanedMessages {
72 if msg.Role == message.Assistant {
73 for _, toolCall := range msg.ToolCalls() {
74 toolMessageIDs[toolCall.ID] = true // Mark as referenced
75 }
76 }
77 }
78
79 // Keep only messages that aren't unreferenced tool messages
80 for _, msg := range cleanedMessages {
81 if msg.Role == message.Tool {
82 for _, toolCall := range msg.ToolResults() {
83 if referenced, exists := toolMessageIDs[toolCall.ToolCallID]; exists && referenced {
84 result = append(result, msg)
85 }
86 }
87 } else {
88 result = append(result, msg)
89 }
90 }
91 return result
92}