main.go

  1package main
  2
  3import (
  4	"context"
  5	"fmt"
  6	"os"
  7	"strings"
  8
  9	"github.com/charmbracelet/crush/internal/ai"
 10	"github.com/charmbracelet/crush/internal/ai/providers"
 11)
 12
 13func main() {
 14	// Check for API key
 15	apiKey := os.Getenv("ANTHROPIC_API_KEY")
 16	if apiKey == "" {
 17		fmt.Println("❌ Please set ANTHROPIC_API_KEY environment variable")
 18		fmt.Println("   export ANTHROPIC_API_KEY=your_api_key_here")
 19		os.Exit(1)
 20	}
 21
 22	fmt.Println("🚀 Streaming Agent Example")
 23	fmt.Println("==========================")
 24	fmt.Println()
 25
 26	// Create OpenAI provider and model
 27	provider := providers.NewAnthropicProvider(providers.WithAnthropicAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
 28	model, err := provider.LanguageModel("claude-sonnet-4-20250514")
 29	if err != nil {
 30		fmt.Println(err)
 31		return
 32	}
 33
 34	// Define input types for type-safe tools
 35	type WeatherInput struct {
 36		Location string `json:"location" description:"The city and country, e.g. 'London, UK'"`
 37		Unit     string `json:"unit,omitempty" enum:"celsius,fahrenheit" description:"Temperature unit (celsius or fahrenheit)"`
 38	}
 39
 40	type CalculatorInput struct {
 41		Expression string `json:"expression" description:"Mathematical expression to evaluate (e.g., '2 + 2', '10 * 5')"`
 42	}
 43
 44	// Create weather tool using the new type-safe API
 45	weatherTool := ai.NewAgentTool(
 46		"get_weather",
 47		"Get the current weather for a specific location",
 48		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 49			// Simulate weather lookup with some fake data
 50			location := input.Location
 51			if location == "" {
 52				location = "Unknown"
 53			}
 54
 55			// Default to celsius if not specified
 56			unit := input.Unit
 57			if unit == "" {
 58				unit = "celsius"
 59			}
 60
 61			// Simulate different temperatures for different cities
 62			var temp string
 63			if strings.Contains(strings.ToLower(location), "pristina") {
 64				temp = "15°C"
 65				if unit == "fahrenheit" {
 66					temp = "59°F"
 67				}
 68			} else if strings.Contains(strings.ToLower(location), "london") {
 69				temp = "12°C"
 70				if unit == "fahrenheit" {
 71					temp = "54°F"
 72				}
 73			} else {
 74				temp = "22°C"
 75				if unit == "fahrenheit" {
 76					temp = "72°F"
 77				}
 78			}
 79
 80			weather := fmt.Sprintf("The current weather in %s is %s with partly cloudy skies and light winds.", location, temp)
 81			return ai.NewTextResponse(weather), nil
 82		},
 83	)
 84
 85	// Create calculator tool using the new type-safe API
 86	calculatorTool := ai.NewAgentTool(
 87		"calculate",
 88		"Perform basic mathematical calculations",
 89		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 90			// Simple calculator simulation
 91			expr := strings.TrimSpace(input.Expression)
 92			if strings.Contains(expr, "2 + 2") || strings.Contains(expr, "2+2") {
 93				return ai.NewTextResponse("2 + 2 = 4"), nil
 94			} else if strings.Contains(expr, "10 * 5") || strings.Contains(expr, "10*5") {
 95				return ai.NewTextResponse("10 * 5 = 50"), nil
 96			} else if strings.Contains(expr, "15 + 27") || strings.Contains(expr, "15+27") {
 97				return ai.NewTextResponse("15 + 27 = 42"), nil
 98			}
 99			return ai.NewTextResponse("I can calculate simple expressions like '2 + 2', '10 * 5', or '15 + 27'"), nil
100		},
101	)
102
103	// Create agent with tools
104	agent := ai.NewAgent(
105		model,
106		ai.WithSystemPrompt("You are a helpful assistant that can check weather and do calculations. Be concise and friendly."),
107		ai.WithTools(weatherTool, calculatorTool),
108	)
109
110	ctx := context.Background()
111
112	// Demonstrate streaming with comprehensive callbacks
113	fmt.Println("💬 Asking: \"What's the weather in Pristina and what's 2 + 2?\"")
114	fmt.Println()
115
116	// Track streaming events
117	var stepCount int
118	var textBuffer strings.Builder
119	var reasoningBuffer strings.Builder
120
121	// Create streaming call with all callbacks
122	streamCall := ai.AgentStreamCall{
123		Prompt: "What's the weather in Pristina and what's 2 + 2?",
124
125		// Agent-level callbacks
126		OnAgentStart: func() {
127			fmt.Println("🎬 Agent started")
128		},
129		OnAgentFinish: func(result *ai.AgentResult) {
130			fmt.Printf("🏁 Agent finished with %d steps, total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens)
131		},
132		OnStepStart: func(stepNumber int) {
133			stepCount++
134			fmt.Printf("📝 Step %d started\n", stepNumber+1)
135		},
136		OnStepFinish: func(stepResult ai.StepResult) {
137			fmt.Printf("✅ Step completed (reason: %s, tokens: %d)\n", stepResult.FinishReason, stepResult.Usage.TotalTokens)
138		},
139		OnFinish: func(result *ai.AgentResult) {
140			fmt.Printf("🎯 Final result ready with %d steps\n", len(result.Steps))
141		},
142		OnError: func(err error) {
143			fmt.Printf("❌ Error: %v\n", err)
144		},
145
146		// Stream part callbacks
147		OnWarnings: func(warnings []ai.CallWarning) error {
148			for _, warning := range warnings {
149				fmt.Printf("⚠️  Warning: %s\n", warning.Message)
150			}
151			return nil
152		},
153		OnTextStart: func(id string) error {
154			fmt.Print("💭 Assistant: ")
155			return nil
156		},
157		OnTextDelta: func(id, text string) error {
158			fmt.Print(text)
159			textBuffer.WriteString(text)
160			return nil
161		},
162		OnTextEnd: func(id string) error {
163			fmt.Println()
164			return nil
165		},
166		OnReasoningStart: func(id string) error {
167			fmt.Print("🤔 Thinking: ")
168			return nil
169		},
170		OnReasoningDelta: func(id, text string) error {
171			reasoningBuffer.WriteString(text)
172			return nil
173		},
174		OnReasoningEnd: func(id string, content ai.ReasoningContent) error {
175			if reasoningBuffer.Len() > 0 {
176				fmt.Printf("%s\n", reasoningBuffer.String())
177				reasoningBuffer.Reset()
178			}
179			return nil
180		},
181		OnToolInputStart: func(id, toolName string) error {
182			fmt.Printf("🔧 Calling tool: %s\n", toolName)
183			return nil
184		},
185		OnToolInputDelta: func(id, delta string) error {
186			// Could show tool input being built, but it's often noisy
187			return nil
188		},
189		OnToolInputEnd: func(id string) error {
190			// Tool input complete
191			return nil
192		},
193		OnToolCall: func(toolCall ai.ToolCallContent) error {
194			fmt.Printf("🛠️  Tool call: %s\n", toolCall.ToolName)
195			fmt.Printf("   Input: %s\n", toolCall.Input)
196			return nil
197		},
198		OnToolResult: func(result ai.ToolResultContent) error {
199			fmt.Printf("🎯 Tool result from %s:\n", result.ToolName)
200			switch output := result.Result.(type) {
201			case ai.ToolResultOutputContentText:
202				fmt.Printf("   %s\n", output.Text)
203			case ai.ToolResultOutputContentError:
204				fmt.Printf("   Error: %s\n", output.Error.Error())
205			}
206			return nil
207		},
208		OnSource: func(source ai.SourceContent) error {
209			fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL)
210			return nil
211		},
212		OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderMetadata) error {
213			fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens)
214			return nil
215		},
216	}
217
218	// Execute streaming agent
219	result, err := agent.Stream(ctx, streamCall)
220	if err != nil {
221		fmt.Printf("❌ Agent failed: %v\n", err)
222		os.Exit(1)
223	}
224
225	// Display final results
226	fmt.Println()
227	fmt.Println("📋 Final Summary")
228	fmt.Println("================")
229	fmt.Printf("Steps executed: %d\n", len(result.Steps))
230	fmt.Printf("Total tokens used: %d (input: %d, output: %d)\n",
231		result.TotalUsage.TotalTokens,
232		result.TotalUsage.InputTokens,
233		result.TotalUsage.OutputTokens)
234
235	if result.TotalUsage.ReasoningTokens > 0 {
236		fmt.Printf("Reasoning tokens: %d\n", result.TotalUsage.ReasoningTokens)
237	}
238
239	fmt.Printf("Final response: %s\n", result.Response.Content.Text())
240
241	// Show step details
242	fmt.Println()
243	fmt.Println("🔍 Step Details")
244	fmt.Println("===============")
245	for i, step := range result.Steps {
246		fmt.Printf("Step %d:\n", i+1)
247		fmt.Printf("  Finish reason: %s\n", step.FinishReason)
248		fmt.Printf("  Content types: ")
249
250		var contentTypes []string
251		for _, content := range step.Content {
252			contentTypes = append(contentTypes, string(content.GetType()))
253		}
254		fmt.Printf("%s\n", strings.Join(contentTypes, ", "))
255
256		// Show tool calls and results
257		toolCalls := step.Content.ToolCalls()
258		if len(toolCalls) > 0 {
259			fmt.Printf("  Tool calls: ")
260			var toolNames []string
261			for _, tc := range toolCalls {
262				toolNames = append(toolNames, tc.ToolName)
263			}
264			fmt.Printf("%s\n", strings.Join(toolNames, ", "))
265		}
266
267		toolResults := step.Content.ToolResults()
268		if len(toolResults) > 0 {
269			fmt.Printf("  Tool results: %d\n", len(toolResults))
270		}
271
272		fmt.Printf("  Tokens: %d\n", step.Usage.TotalTokens)
273		fmt.Println()
274	}
275
276	fmt.Println("✨ Example completed successfully!")
277}