1package main
2
3import (
4 "context"
5 "fmt"
6 "os"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/ai"
10 "github.com/charmbracelet/crush/internal/ai/providers"
11)
12
13func main() {
14 // Check for API key
15 apiKey := os.Getenv("OPENAI_API_KEY")
16 if apiKey == "" {
17 fmt.Println("❌ Please set OPENAI_API_KEY environment variable")
18 fmt.Println(" export OPENAI_API_KEY=your_api_key_here")
19 os.Exit(1)
20 }
21
22 fmt.Println("🚀 Streaming Agent Example")
23 fmt.Println("==========================")
24 fmt.Println()
25
26 // Create OpenAI provider and model
27 provider := providers.NewOpenAIProvider(
28 providers.WithOpenAIApiKey(apiKey),
29 )
30 model := provider.LanguageModel("gpt-4o-mini") // Using mini for faster/cheaper responses
31
32 // Define input types for type-safe tools
33 type WeatherInput struct {
34 Location string `json:"location" description:"The city and country, e.g. 'London, UK'"`
35 Unit string `json:"unit,omitempty" enum:"celsius,fahrenheit" description:"Temperature unit (celsius or fahrenheit)"`
36 }
37
38 type CalculatorInput struct {
39 Expression string `json:"expression" description:"Mathematical expression to evaluate (e.g., '2 + 2', '10 * 5')"`
40 }
41
42 // Create weather tool using the new type-safe API
43 weatherTool := ai.NewTypedToolFunc(
44 "get_weather",
45 "Get the current weather for a specific location",
46 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
47 // Simulate weather lookup with some fake data
48 location := input.Location
49 if location == "" {
50 location = "Unknown"
51 }
52
53 // Default to celsius if not specified
54 unit := input.Unit
55 if unit == "" {
56 unit = "celsius"
57 }
58
59 // Simulate different temperatures for different cities
60 var temp string
61 if strings.Contains(strings.ToLower(location), "pristina") {
62 temp = "15°C"
63 if unit == "fahrenheit" {
64 temp = "59°F"
65 }
66 } else if strings.Contains(strings.ToLower(location), "london") {
67 temp = "12°C"
68 if unit == "fahrenheit" {
69 temp = "54°F"
70 }
71 } else {
72 temp = "22°C"
73 if unit == "fahrenheit" {
74 temp = "72°F"
75 }
76 }
77
78 weather := fmt.Sprintf("The current weather in %s is %s with partly cloudy skies and light winds.", location, temp)
79 return ai.NewTextResponse(weather), nil
80 },
81 )
82
83 // Create calculator tool using the new type-safe API
84 calculatorTool := ai.NewTypedToolFunc(
85 "calculate",
86 "Perform basic mathematical calculations",
87 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
88 // Simple calculator simulation
89 expr := strings.TrimSpace(input.Expression)
90 if strings.Contains(expr, "2 + 2") || strings.Contains(expr, "2+2") {
91 return ai.NewTextResponse("2 + 2 = 4"), nil
92 } else if strings.Contains(expr, "10 * 5") || strings.Contains(expr, "10*5") {
93 return ai.NewTextResponse("10 * 5 = 50"), nil
94 } else if strings.Contains(expr, "15 + 27") || strings.Contains(expr, "15+27") {
95 return ai.NewTextResponse("15 + 27 = 42"), nil
96 }
97 return ai.NewTextResponse("I can calculate simple expressions like '2 + 2', '10 * 5', or '15 + 27'"), nil
98 },
99 )
100
101 // Create agent with tools
102 agent := ai.NewAgent(
103 model,
104 ai.WithSystemPrompt("You are a helpful assistant that can check weather and do calculations. Be concise and friendly."),
105 ai.WithTools(weatherTool, calculatorTool),
106 )
107
108 ctx := context.Background()
109
110 // Demonstrate streaming with comprehensive callbacks
111 fmt.Println("💬 Asking: \"What's the weather in Pristina and what's 2 + 2?\"")
112 fmt.Println()
113
114 // Track streaming events
115 var stepCount int
116 var textBuffer strings.Builder
117 var reasoningBuffer strings.Builder
118
119 // Create streaming call with all callbacks
120 streamCall := ai.AgentStreamCall{
121 Prompt: "What's the weather in Pristina and what's 2 + 2?",
122
123 // Agent-level callbacks
124 OnAgentStart: func() {
125 fmt.Println("🎬 Agent started")
126 },
127 OnAgentFinish: func(result *ai.AgentResult) {
128 fmt.Printf("🏁 Agent finished with %d steps, total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens)
129 },
130 OnStepStart: func(stepNumber int) {
131 stepCount++
132 fmt.Printf("📝 Step %d started\n", stepNumber+1)
133 },
134 OnStepFinish: func(stepResult ai.StepResult) {
135 fmt.Printf("✅ Step completed (reason: %s, tokens: %d)\n", stepResult.FinishReason, stepResult.Usage.TotalTokens)
136 },
137 OnFinish: func(result *ai.AgentResult) {
138 fmt.Printf("🎯 Final result ready with %d steps\n", len(result.Steps))
139 },
140 OnError: func(err error) {
141 fmt.Printf("❌ Error: %v\n", err)
142 },
143
144 // Stream part callbacks
145 OnWarnings: func(warnings []ai.CallWarning) {
146 for _, warning := range warnings {
147 fmt.Printf("⚠️ Warning: %s\n", warning.Message)
148 }
149 },
150 OnTextStart: func(id string) {
151 fmt.Print("💭 Assistant: ")
152 },
153 OnTextDelta: func(id, text string) {
154 fmt.Print(text)
155 textBuffer.WriteString(text)
156 },
157 OnTextEnd: func(id string) {
158 fmt.Println()
159 },
160 OnReasoningStart: func(id string) {
161 fmt.Print("🤔 Thinking: ")
162 },
163 OnReasoningDelta: func(id, text string) {
164 reasoningBuffer.WriteString(text)
165 },
166 OnReasoningEnd: func(id string) {
167 if reasoningBuffer.Len() > 0 {
168 fmt.Printf("%s\n", reasoningBuffer.String())
169 reasoningBuffer.Reset()
170 }
171 },
172 OnToolInputStart: func(id, toolName string) {
173 fmt.Printf("🔧 Calling tool: %s\n", toolName)
174 },
175 OnToolInputDelta: func(id, delta string) {
176 // Could show tool input being built, but it's often noisy
177 },
178 OnToolInputEnd: func(id string) {
179 // Tool input complete
180 },
181 OnToolCall: func(toolCall ai.ToolCallContent) {
182 fmt.Printf("🛠️ Tool call: %s\n", toolCall.ToolName)
183 fmt.Printf(" Input: %s\n", toolCall.Input)
184 },
185 OnToolResult: func(result ai.ToolResultContent) {
186 fmt.Printf("🎯 Tool result from %s:\n", result.ToolName)
187 switch output := result.Result.(type) {
188 case ai.ToolResultOutputContentText:
189 fmt.Printf(" %s\n", output.Text)
190 case ai.ToolResultOutputContentError:
191 fmt.Printf(" Error: %s\n", output.Error.Error())
192 }
193 },
194 OnSource: func(source ai.SourceContent) {
195 fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL)
196 },
197 OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderOptions) {
198 fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens)
199 },
200 OnStreamError: func(err error) {
201 fmt.Printf("💥 Stream error: %v\n", err)
202 },
203 }
204
205 // Execute streaming agent
206 result, err := agent.Stream(ctx, streamCall)
207 if err != nil {
208 fmt.Printf("❌ Agent failed: %v\n", err)
209 os.Exit(1)
210 }
211
212 // Display final results
213 fmt.Println()
214 fmt.Println("📋 Final Summary")
215 fmt.Println("================")
216 fmt.Printf("Steps executed: %d\n", len(result.Steps))
217 fmt.Printf("Total tokens used: %d (input: %d, output: %d)\n",
218 result.TotalUsage.TotalTokens,
219 result.TotalUsage.InputTokens,
220 result.TotalUsage.OutputTokens)
221
222 if result.TotalUsage.ReasoningTokens > 0 {
223 fmt.Printf("Reasoning tokens: %d\n", result.TotalUsage.ReasoningTokens)
224 }
225
226 fmt.Printf("Final response: %s\n", result.Response.Content.Text())
227
228 // Show step details
229 fmt.Println()
230 fmt.Println("🔍 Step Details")
231 fmt.Println("===============")
232 for i, step := range result.Steps {
233 fmt.Printf("Step %d:\n", i+1)
234 fmt.Printf(" Finish reason: %s\n", step.FinishReason)
235 fmt.Printf(" Content types: ")
236
237 var contentTypes []string
238 for _, content := range step.Content {
239 contentTypes = append(contentTypes, string(content.GetType()))
240 }
241 fmt.Printf("%s\n", strings.Join(contentTypes, ", "))
242
243 // Show tool calls and results
244 toolCalls := step.Content.ToolCalls()
245 if len(toolCalls) > 0 {
246 fmt.Printf(" Tool calls: ")
247 var toolNames []string
248 for _, tc := range toolCalls {
249 toolNames = append(toolNames, tc.ToolName)
250 }
251 fmt.Printf("%s\n", strings.Join(toolNames, ", "))
252 }
253
254 toolResults := step.Content.ToolResults()
255 if len(toolResults) > 0 {
256 fmt.Printf(" Tool results: %d\n", len(toolResults))
257 }
258
259 fmt.Printf(" Tokens: %d\n", step.Usage.TotalTokens)
260 fmt.Println()
261 }
262
263 fmt.Println("✨ Example completed successfully!")
264}
265