1package main
2
3import (
4 "context"
5 "fmt"
6 "os"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/ai"
10 "github.com/charmbracelet/crush/internal/ai/providers"
11)
12
13func main() {
14 // Check for API key
15 apiKey := os.Getenv("ANTHROPIC_API_KEY")
16 if apiKey == "" {
17 fmt.Println("❌ Please set ANTHROPIC_API_KEY environment variable")
18 fmt.Println(" export ANTHROPIC_API_KEY=your_api_key_here")
19 os.Exit(1)
20 }
21
22 fmt.Println("🚀 Streaming Agent Example")
23 fmt.Println("==========================")
24 fmt.Println()
25
26 // Create OpenAI provider and model
27 provider := providers.NewAnthropicProvider(providers.WithAnthropicAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
28 model, err := provider.LanguageModel("claude-sonnet-4-20250514")
29 if err != nil {
30 fmt.Println(err)
31 return
32 }
33
34 // Define input types for type-safe tools
35 type WeatherInput struct {
36 Location string `json:"location" description:"The city and country, e.g. 'London, UK'"`
37 Unit string `json:"unit,omitempty" enum:"celsius,fahrenheit" description:"Temperature unit (celsius or fahrenheit)"`
38 }
39
40 type CalculatorInput struct {
41 Expression string `json:"expression" description:"Mathematical expression to evaluate (e.g., '2 + 2', '10 * 5')"`
42 }
43
44 // Create weather tool using the new type-safe API
45 weatherTool := ai.NewAgentTool(
46 "get_weather",
47 "Get the current weather for a specific location",
48 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
49 // Simulate weather lookup with some fake data
50 location := input.Location
51 if location == "" {
52 location = "Unknown"
53 }
54
55 // Default to celsius if not specified
56 unit := input.Unit
57 if unit == "" {
58 unit = "celsius"
59 }
60
61 // Simulate different temperatures for different cities
62 var temp string
63 if strings.Contains(strings.ToLower(location), "pristina") {
64 temp = "15°C"
65 if unit == "fahrenheit" {
66 temp = "59°F"
67 }
68 } else if strings.Contains(strings.ToLower(location), "london") {
69 temp = "12°C"
70 if unit == "fahrenheit" {
71 temp = "54°F"
72 }
73 } else {
74 temp = "22°C"
75 if unit == "fahrenheit" {
76 temp = "72°F"
77 }
78 }
79
80 weather := fmt.Sprintf("The current weather in %s is %s with partly cloudy skies and light winds.", location, temp)
81 return ai.NewTextResponse(weather), nil
82 },
83 )
84
85 // Create calculator tool using the new type-safe API
86 calculatorTool := ai.NewAgentTool(
87 "calculate",
88 "Perform basic mathematical calculations",
89 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
90 // Simple calculator simulation
91 expr := strings.TrimSpace(input.Expression)
92 if strings.Contains(expr, "2 + 2") || strings.Contains(expr, "2+2") {
93 return ai.NewTextResponse("2 + 2 = 4"), nil
94 } else if strings.Contains(expr, "10 * 5") || strings.Contains(expr, "10*5") {
95 return ai.NewTextResponse("10 * 5 = 50"), nil
96 } else if strings.Contains(expr, "15 + 27") || strings.Contains(expr, "15+27") {
97 return ai.NewTextResponse("15 + 27 = 42"), nil
98 }
99 return ai.NewTextResponse("I can calculate simple expressions like '2 + 2', '10 * 5', or '15 + 27'"), nil
100 },
101 )
102
103 // Create agent with tools
104 agent := ai.NewAgent(
105 model,
106 ai.WithSystemPrompt("You are a helpful assistant that can check weather and do calculations. Be concise and friendly."),
107 ai.WithTools(weatherTool, calculatorTool),
108 )
109
110 ctx := context.Background()
111
112 // Demonstrate streaming with comprehensive callbacks
113 fmt.Println("💬 Asking: \"What's the weather in Pristina and what's 2 + 2?\"")
114 fmt.Println()
115
116 // Track streaming events
117 var stepCount int
118 var textBuffer strings.Builder
119 var reasoningBuffer strings.Builder
120
121 // Create streaming call with all callbacks
122 streamCall := ai.AgentStreamCall{
123 Prompt: "What's the weather in Pristina and what's 2 + 2?",
124
125 // Agent-level callbacks
126 OnAgentStart: func() {
127 fmt.Println("🎬 Agent started")
128 },
129 OnAgentFinish: func(result *ai.AgentResult) {
130 fmt.Printf("🏁 Agent finished with %d steps, total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens)
131 },
132 OnStepStart: func(stepNumber int) {
133 stepCount++
134 fmt.Printf("📝 Step %d started\n", stepNumber+1)
135 },
136 OnStepFinish: func(stepResult ai.StepResult) {
137 fmt.Printf("✅ Step completed (reason: %s, tokens: %d)\n", stepResult.FinishReason, stepResult.Usage.TotalTokens)
138 },
139 OnFinish: func(result *ai.AgentResult) {
140 fmt.Printf("🎯 Final result ready with %d steps\n", len(result.Steps))
141 },
142 OnError: func(err error) {
143 fmt.Printf("❌ Error: %v\n", err)
144 },
145
146 // Stream part callbacks
147 OnWarnings: func(warnings []ai.CallWarning) {
148 for _, warning := range warnings {
149 fmt.Printf("⚠️ Warning: %s\n", warning.Message)
150 }
151 },
152 OnTextStart: func(id string) {
153 fmt.Print("💭 Assistant: ")
154 },
155 OnTextDelta: func(id, text string) {
156 fmt.Print(text)
157 textBuffer.WriteString(text)
158 },
159 OnTextEnd: func(id string) {
160 fmt.Println()
161 },
162 OnReasoningStart: func(id string) {
163 fmt.Print("🤔 Thinking: ")
164 },
165 OnReasoningDelta: func(id, text string) {
166 reasoningBuffer.WriteString(text)
167 },
168 OnReasoningEnd: func(id string, content ai.ReasoningContent) {
169 if reasoningBuffer.Len() > 0 {
170 fmt.Printf("%s\n", reasoningBuffer.String())
171 reasoningBuffer.Reset()
172 }
173 },
174 OnToolInputStart: func(id, toolName string) {
175 fmt.Printf("🔧 Calling tool: %s\n", toolName)
176 },
177 OnToolInputDelta: func(id, delta string) {
178 // Could show tool input being built, but it's often noisy
179 },
180 OnToolInputEnd: func(id string) {
181 // Tool input complete
182 },
183 OnToolCall: func(toolCall ai.ToolCallContent) {
184 fmt.Printf("🛠️ Tool call: %s\n", toolCall.ToolName)
185 fmt.Printf(" Input: %s\n", toolCall.Input)
186 },
187 OnToolResult: func(result ai.ToolResultContent) {
188 fmt.Printf("🎯 Tool result from %s:\n", result.ToolName)
189 switch output := result.Result.(type) {
190 case ai.ToolResultOutputContentText:
191 fmt.Printf(" %s\n", output.Text)
192 case ai.ToolResultOutputContentError:
193 fmt.Printf(" Error: %s\n", output.Error.Error())
194 }
195 },
196 OnSource: func(source ai.SourceContent) {
197 fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL)
198 },
199 OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderMetadata) {
200 fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens)
201 },
202 OnStreamError: func(err error) {
203 fmt.Printf("💥 Stream error: %v\n", err)
204 },
205 }
206
207 // Execute streaming agent
208 result, err := agent.Stream(ctx, streamCall)
209 if err != nil {
210 fmt.Printf("❌ Agent failed: %v\n", err)
211 os.Exit(1)
212 }
213
214 // Display final results
215 fmt.Println()
216 fmt.Println("📋 Final Summary")
217 fmt.Println("================")
218 fmt.Printf("Steps executed: %d\n", len(result.Steps))
219 fmt.Printf("Total tokens used: %d (input: %d, output: %d)\n",
220 result.TotalUsage.TotalTokens,
221 result.TotalUsage.InputTokens,
222 result.TotalUsage.OutputTokens)
223
224 if result.TotalUsage.ReasoningTokens > 0 {
225 fmt.Printf("Reasoning tokens: %d\n", result.TotalUsage.ReasoningTokens)
226 }
227
228 fmt.Printf("Final response: %s\n", result.Response.Content.Text())
229
230 // Show step details
231 fmt.Println()
232 fmt.Println("🔍 Step Details")
233 fmt.Println("===============")
234 for i, step := range result.Steps {
235 fmt.Printf("Step %d:\n", i+1)
236 fmt.Printf(" Finish reason: %s\n", step.FinishReason)
237 fmt.Printf(" Content types: ")
238
239 var contentTypes []string
240 for _, content := range step.Content {
241 contentTypes = append(contentTypes, string(content.GetType()))
242 }
243 fmt.Printf("%s\n", strings.Join(contentTypes, ", "))
244
245 // Show tool calls and results
246 toolCalls := step.Content.ToolCalls()
247 if len(toolCalls) > 0 {
248 fmt.Printf(" Tool calls: ")
249 var toolNames []string
250 for _, tc := range toolCalls {
251 toolNames = append(toolNames, tc.ToolName)
252 }
253 fmt.Printf("%s\n", strings.Join(toolNames, ", "))
254 }
255
256 toolResults := step.Content.ToolResults()
257 if len(toolResults) > 0 {
258 fmt.Printf(" Tool results: %d\n", len(toolResults))
259 }
260
261 fmt.Printf(" Tokens: %d\n", step.Usage.TotalTokens)
262 fmt.Println()
263 }
264
265 fmt.Println("✨ Example completed successfully!")
266}