1package main
2
import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"strings"

	"github.com/charmbracelet/crush/internal/ai"
	"github.com/charmbracelet/crush/internal/ai/providers"
	"github.com/charmbracelet/crush/internal/llm/tools"
)
13
14// WeatherTool is a simple tool that simulates weather lookup
15type WeatherTool struct{}
16
17func (w *WeatherTool) Info() tools.ToolInfo {
18 return tools.ToolInfo{
19 Name: "get_weather",
20 Description: "Get the current weather for a specific location",
21 Parameters: map[string]any{
22 "location": map[string]any{
23 "type": "string",
24 "description": "The city and country, e.g. 'London, UK'",
25 },
26 "unit": map[string]any{
27 "type": "string",
28 "description": "Temperature unit (celsius or fahrenheit)",
29 "enum": []string{"celsius", "fahrenheit"},
30 "default": "celsius",
31 },
32 },
33 Required: []string{"location"},
34 }
35}
36
37func (w *WeatherTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
38 // Simulate weather lookup with some fake data
39 location := "Unknown"
40 if strings.Contains(params.Input, "pristina") || strings.Contains(params.Input, "Pristina") {
41 location = "Pristina, Kosovo"
42 } else if strings.Contains(params.Input, "london") || strings.Contains(params.Input, "London") {
43 location = "London, UK"
44 } else if strings.Contains(params.Input, "new york") || strings.Contains(params.Input, "New York") {
45 location = "New York, USA"
46 }
47
48 unit := "celsius"
49 if strings.Contains(params.Input, "fahrenheit") {
50 unit = "fahrenheit"
51 }
52
53 var temp string
54 if unit == "fahrenheit" {
55 temp = "72°F"
56 } else {
57 temp = "22°C"
58 }
59
60 weather := fmt.Sprintf("The current weather in %s is %s with partly cloudy skies and light winds.", location, temp)
61 return tools.NewTextResponse(weather), nil
62}
63
64// CalculatorTool demonstrates a second tool for multi-tool scenarios
65type CalculatorTool struct{}
66
67func (c *CalculatorTool) Info() tools.ToolInfo {
68 return tools.ToolInfo{
69 Name: "calculate",
70 Description: "Perform basic mathematical calculations",
71 Parameters: map[string]any{
72 "expression": map[string]any{
73 "type": "string",
74 "description": "Mathematical expression to evaluate (e.g., '2 + 2', '10 * 5')",
75 },
76 },
77 Required: []string{"expression"},
78 }
79}
80
81func (c *CalculatorTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
82 // Simple calculator simulation
83 expr := strings.ReplaceAll(params.Input, "\"", "")
84 if strings.Contains(expr, "2 + 2") || strings.Contains(expr, "2+2") {
85 return tools.NewTextResponse("2 + 2 = 4"), nil
86 } else if strings.Contains(expr, "10 * 5") || strings.Contains(expr, "10*5") {
87 return tools.NewTextResponse("10 * 5 = 50"), nil
88 }
89 return tools.NewTextResponse("I can calculate simple expressions like '2 + 2' or '10 * 5'"), nil
90}
91
92func main() {
93 // Check for API key
94 apiKey := os.Getenv("OPENAI_API_KEY")
95 if apiKey == "" {
96 fmt.Println("❌ Please set OPENAI_API_KEY environment variable")
97 fmt.Println(" export OPENAI_API_KEY=your_api_key_here")
98 os.Exit(1)
99 }
100
101 fmt.Println("🚀 Streaming Agent Example")
102 fmt.Println("==========================")
103 fmt.Println()
104
105 // Create OpenAI provider and model
106 provider := providers.NewOpenAIProvider(
107 providers.WithOpenAIApiKey(apiKey),
108 )
109 model := provider.LanguageModel("gpt-4o-mini") // Using mini for faster/cheaper responses
110
111 // Create agent with tools
112 agent := ai.NewAgent(
113 model,
114 ai.WithSystemPrompt("You are a helpful assistant that can check weather and do calculations. Be concise and friendly."),
115 ai.WithTools(&WeatherTool{}, &CalculatorTool{}),
116 )
117
118 ctx := context.Background()
119
120 // Demonstrate streaming with comprehensive callbacks
121 fmt.Println("💬 Asking: \"What's the weather in Pristina and what's 2 + 2?\"")
122 fmt.Println()
123
124 // Track streaming events
125 var stepCount int
126 var textBuffer strings.Builder
127 var reasoningBuffer strings.Builder
128
129 // Create streaming call with all callbacks
130 streamCall := ai.AgentStreamCall{
131 Prompt: "What's the weather in Pristina and what's 2 + 2?",
132
133 // Agent-level callbacks
134 OnAgentStart: func() {
135 fmt.Println("🎬 Agent started")
136 },
137 OnAgentFinish: func(result *ai.AgentResult) {
138 fmt.Printf("🏁 Agent finished with %d steps, total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens)
139 },
140 OnStepStart: func(stepNumber int) {
141 stepCount++
142 fmt.Printf("📝 Step %d started\n", stepNumber+1)
143 },
144 OnStepFinish: func(stepResult ai.StepResult) {
145 fmt.Printf("✅ Step completed (reason: %s, tokens: %d)\n", stepResult.FinishReason, stepResult.Usage.TotalTokens)
146 },
147 OnFinish: func(result *ai.AgentResult) {
148 fmt.Printf("🎯 Final result ready with %d steps\n", len(result.Steps))
149 },
150 OnError: func(err error) {
151 fmt.Printf("❌ Error: %v\n", err)
152 },
153
154 // Stream part callbacks
155 OnWarnings: func(warnings []ai.CallWarning) {
156 for _, warning := range warnings {
157 fmt.Printf("⚠️ Warning: %s\n", warning.Message)
158 }
159 },
160 OnTextStart: func(id string) {
161 fmt.Print("💭 Assistant: ")
162 },
163 OnTextDelta: func(id, text string) {
164 fmt.Print(text)
165 textBuffer.WriteString(text)
166 },
167 OnTextEnd: func(id string) {
168 fmt.Println()
169 },
170 OnReasoningStart: func(id string) {
171 fmt.Print("🤔 Thinking: ")
172 },
173 OnReasoningDelta: func(id, text string) {
174 reasoningBuffer.WriteString(text)
175 },
176 OnReasoningEnd: func(id string) {
177 if reasoningBuffer.Len() > 0 {
178 fmt.Printf("%s\n", reasoningBuffer.String())
179 reasoningBuffer.Reset()
180 }
181 },
182 OnToolInputStart: func(id, toolName string) {
183 fmt.Printf("🔧 Calling tool: %s\n", toolName)
184 },
185 OnToolInputDelta: func(id, delta string) {
186 // Could show tool input being built, but it's often noisy
187 },
188 OnToolInputEnd: func(id string) {
189 // Tool input complete
190 },
191 OnToolCall: func(toolCall ai.ToolCallContent) {
192 fmt.Printf("🛠️ Tool call: %s\n", toolCall.ToolName)
193 fmt.Printf(" Input: %s\n", toolCall.Input)
194 },
195 OnToolResult: func(result ai.ToolResultContent) {
196 fmt.Printf("🎯 Tool result from %s:\n", result.ToolName)
197 switch output := result.Result.(type) {
198 case ai.ToolResultOutputContentText:
199 fmt.Printf(" %s\n", output.Text)
200 case ai.ToolResultOutputContentError:
201 fmt.Printf(" Error: %s\n", output.Error.Error())
202 }
203 },
204 OnSource: func(source ai.SourceContent) {
205 fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL)
206 },
207 OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderOptions) {
208 fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens)
209 },
210 OnStreamError: func(err error) {
211 fmt.Printf("💥 Stream error: %v\n", err)
212 },
213 }
214
215 // Execute streaming agent
216 result, err := agent.Stream(ctx, streamCall)
217 if err != nil {
218 fmt.Printf("❌ Agent failed: %v\n", err)
219 os.Exit(1)
220 }
221
222 // Display final results
223 fmt.Println()
224 fmt.Println("📋 Final Summary")
225 fmt.Println("================")
226 fmt.Printf("Steps executed: %d\n", len(result.Steps))
227 fmt.Printf("Total tokens used: %d (input: %d, output: %d)\n",
228 result.TotalUsage.TotalTokens,
229 result.TotalUsage.InputTokens,
230 result.TotalUsage.OutputTokens)
231
232 if result.TotalUsage.ReasoningTokens > 0 {
233 fmt.Printf("Reasoning tokens: %d\n", result.TotalUsage.ReasoningTokens)
234 }
235
236 fmt.Printf("Final response: %s\n", result.Response.Content.Text())
237
238 // Show step details
239 fmt.Println()
240 fmt.Println("🔍 Step Details")
241 fmt.Println("===============")
242 for i, step := range result.Steps {
243 fmt.Printf("Step %d:\n", i+1)
244 fmt.Printf(" Finish reason: %s\n", step.FinishReason)
245 fmt.Printf(" Content types: ")
246
247 var contentTypes []string
248 for _, content := range step.Content {
249 contentTypes = append(contentTypes, string(content.GetType()))
250 }
251 fmt.Printf("%s\n", strings.Join(contentTypes, ", "))
252
253 // Show tool calls and results
254 toolCalls := step.Content.ToolCalls()
255 if len(toolCalls) > 0 {
256 fmt.Printf(" Tool calls: ")
257 var toolNames []string
258 for _, tc := range toolCalls {
259 toolNames = append(toolNames, tc.ToolName)
260 }
261 fmt.Printf("%s\n", strings.Join(toolNames, ", "))
262 }
263
264 toolResults := step.Content.ToolResults()
265 if len(toolResults) > 0 {
266 fmt.Printf(" Tool results: %d\n", len(toolResults))
267 }
268
269 fmt.Printf(" Tokens: %d\n", step.Usage.TotalTokens)
270 fmt.Println()
271 }
272
273 fmt.Println("✨ Example completed successfully!")
274}