1package main
2
3import (
4 "context"
5 "fmt"
6 "os"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/ai"
10 "github.com/charmbracelet/crush/internal/ai/providers"
11)
12
13func main() {
14 // Check for API key
15 apiKey := os.Getenv("ANTHROPIC_API_KEY")
16 if apiKey == "" {
17 fmt.Println("❌ Please set ANTHROPIC_API_KEY environment variable")
18 fmt.Println(" export ANTHROPIC_API_KEY=your_api_key_here")
19 os.Exit(1)
20 }
21
22 fmt.Println("🚀 Streaming Agent Example")
23 fmt.Println("==========================")
24 fmt.Println()
25
26 // Create OpenAI provider and model
27 provider := providers.NewAnthropicProvider(providers.WithAnthropicAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
28 model, err := provider.LanguageModel("claude-sonnet-4-20250514")
29 if err != nil {
30 fmt.Println(err)
31 return
32 }
33
34 // Define input types for type-safe tools
35 type WeatherInput struct {
36 Location string `json:"location" description:"The city and country, e.g. 'London, UK'"`
37 Unit string `json:"unit,omitempty" enum:"celsius,fahrenheit" description:"Temperature unit (celsius or fahrenheit)"`
38 }
39
40 type CalculatorInput struct {
41 Expression string `json:"expression" description:"Mathematical expression to evaluate (e.g., '2 + 2', '10 * 5')"`
42 }
43
44 // Create weather tool using the new type-safe API
45 weatherTool := ai.NewAgentTool(
46 "get_weather",
47 "Get the current weather for a specific location",
48 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
49 // Simulate weather lookup with some fake data
50 location := input.Location
51 if location == "" {
52 location = "Unknown"
53 }
54
55 // Default to celsius if not specified
56 unit := input.Unit
57 if unit == "" {
58 unit = "celsius"
59 }
60
61 // Simulate different temperatures for different cities
62 var temp string
63 if strings.Contains(strings.ToLower(location), "pristina") {
64 temp = "15°C"
65 if unit == "fahrenheit" {
66 temp = "59°F"
67 }
68 } else if strings.Contains(strings.ToLower(location), "london") {
69 temp = "12°C"
70 if unit == "fahrenheit" {
71 temp = "54°F"
72 }
73 } else {
74 temp = "22°C"
75 if unit == "fahrenheit" {
76 temp = "72°F"
77 }
78 }
79
80 weather := fmt.Sprintf("The current weather in %s is %s with partly cloudy skies and light winds.", location, temp)
81 return ai.NewTextResponse(weather), nil
82 },
83 )
84
85 // Create calculator tool using the new type-safe API
86 calculatorTool := ai.NewAgentTool(
87 "calculate",
88 "Perform basic mathematical calculations",
89 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
90 // Simple calculator simulation
91 expr := strings.TrimSpace(input.Expression)
92 if strings.Contains(expr, "2 + 2") || strings.Contains(expr, "2+2") {
93 return ai.NewTextResponse("2 + 2 = 4"), nil
94 } else if strings.Contains(expr, "10 * 5") || strings.Contains(expr, "10*5") {
95 return ai.NewTextResponse("10 * 5 = 50"), nil
96 } else if strings.Contains(expr, "15 + 27") || strings.Contains(expr, "15+27") {
97 return ai.NewTextResponse("15 + 27 = 42"), nil
98 }
99 return ai.NewTextResponse("I can calculate simple expressions like '2 + 2', '10 * 5', or '15 + 27'"), nil
100 },
101 )
102
103 // Create agent with tools
104 agent := ai.NewAgent(
105 model,
106 ai.WithSystemPrompt("You are a helpful assistant that can check weather and do calculations. Be concise and friendly."),
107 ai.WithTools(weatherTool, calculatorTool),
108 )
109
110 ctx := context.Background()
111
112 // Demonstrate streaming with comprehensive callbacks
113 fmt.Println("💬 Asking: \"What's the weather in Pristina and what's 2 + 2?\"")
114 fmt.Println()
115
116 // Track streaming events
117 var stepCount int
118 var textBuffer strings.Builder
119 var reasoningBuffer strings.Builder
120
121 // Create streaming call with all callbacks
122 streamCall := ai.AgentStreamCall{
123 Prompt: "What's the weather in Pristina and what's 2 + 2?",
124
125 // Agent-level callbacks
126 OnAgentStart: func() {
127 fmt.Println("🎬 Agent started")
128 },
129 OnAgentFinish: func(result *ai.AgentResult) {
130 fmt.Printf("🏁 Agent finished with %d steps, total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens)
131 },
132 OnStepStart: func(stepNumber int) {
133 stepCount++
134 fmt.Printf("📝 Step %d started\n", stepNumber+1)
135 },
136 OnStepFinish: func(stepResult ai.StepResult) {
137 fmt.Printf("✅ Step completed (reason: %s, tokens: %d)\n", stepResult.FinishReason, stepResult.Usage.TotalTokens)
138 },
139 OnFinish: func(result *ai.AgentResult) {
140 fmt.Printf("🎯 Final result ready with %d steps\n", len(result.Steps))
141 },
142 OnError: func(err error) {
143 fmt.Printf("❌ Error: %v\n", err)
144 },
145
146 // Stream part callbacks
147 OnWarnings: func(warnings []ai.CallWarning) error {
148 for _, warning := range warnings {
149 fmt.Printf("⚠️ Warning: %s\n", warning.Message)
150 }
151 return nil
152 },
153 OnTextStart: func(id string) error {
154 fmt.Print("💭 Assistant: ")
155 return nil
156 },
157 OnTextDelta: func(id, text string) error {
158 fmt.Print(text)
159 textBuffer.WriteString(text)
160 return nil
161 },
162 OnTextEnd: func(id string) error {
163 fmt.Println()
164 return nil
165 },
166 OnReasoningStart: func(id string) error {
167 fmt.Print("🤔 Thinking: ")
168 return nil
169 },
170 OnReasoningDelta: func(id, text string) error {
171 reasoningBuffer.WriteString(text)
172 return nil
173 },
174 OnReasoningEnd: func(id string, content ai.ReasoningContent) error {
175 if reasoningBuffer.Len() > 0 {
176 fmt.Printf("%s\n", reasoningBuffer.String())
177 reasoningBuffer.Reset()
178 }
179 return nil
180 },
181 OnToolInputStart: func(id, toolName string) error {
182 fmt.Printf("🔧 Calling tool: %s\n", toolName)
183 return nil
184 },
185 OnToolInputDelta: func(id, delta string) error {
186 // Could show tool input being built, but it's often noisy
187 return nil
188 },
189 OnToolInputEnd: func(id string) error {
190 // Tool input complete
191 return nil
192 },
193 OnToolCall: func(toolCall ai.ToolCallContent) error {
194 fmt.Printf("🛠️ Tool call: %s\n", toolCall.ToolName)
195 fmt.Printf(" Input: %s\n", toolCall.Input)
196 return nil
197 },
198 OnToolResult: func(result ai.ToolResultContent) error {
199 fmt.Printf("🎯 Tool result from %s:\n", result.ToolName)
200 switch output := result.Result.(type) {
201 case ai.ToolResultOutputContentText:
202 fmt.Printf(" %s\n", output.Text)
203 case ai.ToolResultOutputContentError:
204 fmt.Printf(" Error: %s\n", output.Error.Error())
205 }
206 return nil
207 },
208 OnSource: func(source ai.SourceContent) error {
209 fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL)
210 return nil
211 },
212 OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderMetadata) error {
213 fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens)
214 return nil
215 },
216 }
217
218 // Execute streaming agent
219 result, err := agent.Stream(ctx, streamCall)
220 if err != nil {
221 fmt.Printf("❌ Agent failed: %v\n", err)
222 os.Exit(1)
223 }
224
225 // Display final results
226 fmt.Println()
227 fmt.Println("📋 Final Summary")
228 fmt.Println("================")
229 fmt.Printf("Steps executed: %d\n", len(result.Steps))
230 fmt.Printf("Total tokens used: %d (input: %d, output: %d)\n",
231 result.TotalUsage.TotalTokens,
232 result.TotalUsage.InputTokens,
233 result.TotalUsage.OutputTokens)
234
235 if result.TotalUsage.ReasoningTokens > 0 {
236 fmt.Printf("Reasoning tokens: %d\n", result.TotalUsage.ReasoningTokens)
237 }
238
239 fmt.Printf("Final response: %s\n", result.Response.Content.Text())
240
241 // Show step details
242 fmt.Println()
243 fmt.Println("🔍 Step Details")
244 fmt.Println("===============")
245 for i, step := range result.Steps {
246 fmt.Printf("Step %d:\n", i+1)
247 fmt.Printf(" Finish reason: %s\n", step.FinishReason)
248 fmt.Printf(" Content types: ")
249
250 var contentTypes []string
251 for _, content := range step.Content {
252 contentTypes = append(contentTypes, string(content.GetType()))
253 }
254 fmt.Printf("%s\n", strings.Join(contentTypes, ", "))
255
256 // Show tool calls and results
257 toolCalls := step.Content.ToolCalls()
258 if len(toolCalls) > 0 {
259 fmt.Printf(" Tool calls: ")
260 var toolNames []string
261 for _, tc := range toolCalls {
262 toolNames = append(toolNames, tc.ToolName)
263 }
264 fmt.Printf("%s\n", strings.Join(toolNames, ", "))
265 }
266
267 toolResults := step.Content.ToolResults()
268 if len(toolResults) > 0 {
269 fmt.Printf(" Tool results: %d\n", len(toolResults))
270 }
271
272 fmt.Printf(" Tokens: %d\n", step.Usage.TotalTokens)
273 fmt.Println()
274 }
275
276 fmt.Println("✨ Example completed successfully!")
277}