main.go

  1package main
  2
  3import (
  4	"context"
  5	"fmt"
  6	"os"
  7	"strings"
  8
  9	"github.com/charmbracelet/crush/internal/ai"
 10	"github.com/charmbracelet/crush/internal/ai/providers"
 11)
 12
 13func main() {
 14	// Check for API key
 15	apiKey := os.Getenv("ANTHROPIC_API_KEY")
 16	if apiKey == "" {
 17		fmt.Println("❌ Please set ANTHROPIC_API_KEY environment variable")
 18		fmt.Println("   export ANTHROPIC_API_KEY=your_api_key_here")
 19		os.Exit(1)
 20	}
 21
 22	fmt.Println("🚀 Streaming Agent Example")
 23	fmt.Println("==========================")
 24	fmt.Println()
 25
 26	// Create OpenAI provider and model
 27	provider := providers.NewAnthropicProvider(providers.WithAnthropicAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
 28	model, err := provider.LanguageModel("claude-sonnet-4-20250514")
 29	if err != nil {
 30		fmt.Println(err)
 31		return
 32	}
 33
 34	// Define input types for type-safe tools
 35	type WeatherInput struct {
 36		Location string `json:"location" description:"The city and country, e.g. 'London, UK'"`
 37		Unit     string `json:"unit,omitempty" enum:"celsius,fahrenheit" description:"Temperature unit (celsius or fahrenheit)"`
 38	}
 39
 40	type CalculatorInput struct {
 41		Expression string `json:"expression" description:"Mathematical expression to evaluate (e.g., '2 + 2', '10 * 5')"`
 42	}
 43
 44	// Create weather tool using the new type-safe API
 45	weatherTool := ai.NewAgentTool(
 46		"get_weather",
 47		"Get the current weather for a specific location",
 48		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 49			// Simulate weather lookup with some fake data
 50			location := input.Location
 51			if location == "" {
 52				location = "Unknown"
 53			}
 54
 55			// Default to celsius if not specified
 56			unit := input.Unit
 57			if unit == "" {
 58				unit = "celsius"
 59			}
 60
 61			// Simulate different temperatures for different cities
 62			var temp string
 63			if strings.Contains(strings.ToLower(location), "pristina") {
 64				temp = "15°C"
 65				if unit == "fahrenheit" {
 66					temp = "59°F"
 67				}
 68			} else if strings.Contains(strings.ToLower(location), "london") {
 69				temp = "12°C"
 70				if unit == "fahrenheit" {
 71					temp = "54°F"
 72				}
 73			} else {
 74				temp = "22°C"
 75				if unit == "fahrenheit" {
 76					temp = "72°F"
 77				}
 78			}
 79
 80			weather := fmt.Sprintf("The current weather in %s is %s with partly cloudy skies and light winds.", location, temp)
 81			return ai.NewTextResponse(weather), nil
 82		},
 83	)
 84
 85	// Create calculator tool using the new type-safe API
 86	calculatorTool := ai.NewAgentTool(
 87		"calculate",
 88		"Perform basic mathematical calculations",
 89		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 90			// Simple calculator simulation
 91			expr := strings.TrimSpace(input.Expression)
 92			if strings.Contains(expr, "2 + 2") || strings.Contains(expr, "2+2") {
 93				return ai.NewTextResponse("2 + 2 = 4"), nil
 94			} else if strings.Contains(expr, "10 * 5") || strings.Contains(expr, "10*5") {
 95				return ai.NewTextResponse("10 * 5 = 50"), nil
 96			} else if strings.Contains(expr, "15 + 27") || strings.Contains(expr, "15+27") {
 97				return ai.NewTextResponse("15 + 27 = 42"), nil
 98			}
 99			return ai.NewTextResponse("I can calculate simple expressions like '2 + 2', '10 * 5', or '15 + 27'"), nil
100		},
101	)
102
103	// Create agent with tools
104	agent := ai.NewAgent(
105		model,
106		ai.WithSystemPrompt("You are a helpful assistant that can check weather and do calculations. Be concise and friendly."),
107		ai.WithTools(weatherTool, calculatorTool),
108	)
109
110	ctx := context.Background()
111
112	// Demonstrate streaming with comprehensive callbacks
113	fmt.Println("💬 Asking: \"What's the weather in Pristina and what's 2 + 2?\"")
114	fmt.Println()
115
116	// Track streaming events
117	var stepCount int
118	var textBuffer strings.Builder
119	var reasoningBuffer strings.Builder
120
121	// Create streaming call with all callbacks
122	streamCall := ai.AgentStreamCall{
123		Prompt: "What's the weather in Pristina and what's 2 + 2?",
124
125		// Agent-level callbacks
126		OnAgentStart: func() {
127			fmt.Println("🎬 Agent started")
128		},
129		OnAgentFinish: func(result *ai.AgentResult) {
130			fmt.Printf("🏁 Agent finished with %d steps, total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens)
131		},
132		OnStepStart: func(stepNumber int) {
133			stepCount++
134			fmt.Printf("📝 Step %d started\n", stepNumber+1)
135		},
136		OnStepFinish: func(stepResult ai.StepResult) {
137			fmt.Printf("✅ Step completed (reason: %s, tokens: %d)\n", stepResult.FinishReason, stepResult.Usage.TotalTokens)
138		},
139		OnFinish: func(result *ai.AgentResult) {
140			fmt.Printf("🎯 Final result ready with %d steps\n", len(result.Steps))
141		},
142		OnError: func(err error) {
143			fmt.Printf("❌ Error: %v\n", err)
144		},
145
146		// Stream part callbacks
147		OnWarnings: func(warnings []ai.CallWarning) {
148			for _, warning := range warnings {
149				fmt.Printf("⚠️  Warning: %s\n", warning.Message)
150			}
151		},
152		OnTextStart: func(id string) {
153			fmt.Print("💭 Assistant: ")
154		},
155		OnTextDelta: func(id, text string) {
156			fmt.Print(text)
157			textBuffer.WriteString(text)
158		},
159		OnTextEnd: func(id string) {
160			fmt.Println()
161		},
162		OnReasoningStart: func(id string) {
163			fmt.Print("🤔 Thinking: ")
164		},
165		OnReasoningDelta: func(id, text string) {
166			reasoningBuffer.WriteString(text)
167		},
168		OnReasoningEnd: func(id string, content ai.ReasoningContent) {
169			if reasoningBuffer.Len() > 0 {
170				fmt.Printf("%s\n", reasoningBuffer.String())
171				reasoningBuffer.Reset()
172			}
173		},
174		OnToolInputStart: func(id, toolName string) {
175			fmt.Printf("🔧 Calling tool: %s\n", toolName)
176		},
177		OnToolInputDelta: func(id, delta string) {
178			// Could show tool input being built, but it's often noisy
179		},
180		OnToolInputEnd: func(id string) {
181			// Tool input complete
182		},
183		OnToolCall: func(toolCall ai.ToolCallContent) {
184			fmt.Printf("🛠️  Tool call: %s\n", toolCall.ToolName)
185			fmt.Printf("   Input: %s\n", toolCall.Input)
186		},
187		OnToolResult: func(result ai.ToolResultContent) {
188			fmt.Printf("🎯 Tool result from %s:\n", result.ToolName)
189			switch output := result.Result.(type) {
190			case ai.ToolResultOutputContentText:
191				fmt.Printf("   %s\n", output.Text)
192			case ai.ToolResultOutputContentError:
193				fmt.Printf("   Error: %s\n", output.Error.Error())
194			}
195		},
196		OnSource: func(source ai.SourceContent) {
197			fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL)
198		},
199		OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderMetadata) {
200			fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens)
201		},
202		OnStreamError: func(err error) {
203			fmt.Printf("💥 Stream error: %v\n", err)
204		},
205	}
206
207	// Execute streaming agent
208	result, err := agent.Stream(ctx, streamCall)
209	if err != nil {
210		fmt.Printf("❌ Agent failed: %v\n", err)
211		os.Exit(1)
212	}
213
214	// Display final results
215	fmt.Println()
216	fmt.Println("📋 Final Summary")
217	fmt.Println("================")
218	fmt.Printf("Steps executed: %d\n", len(result.Steps))
219	fmt.Printf("Total tokens used: %d (input: %d, output: %d)\n",
220		result.TotalUsage.TotalTokens,
221		result.TotalUsage.InputTokens,
222		result.TotalUsage.OutputTokens)
223
224	if result.TotalUsage.ReasoningTokens > 0 {
225		fmt.Printf("Reasoning tokens: %d\n", result.TotalUsage.ReasoningTokens)
226	}
227
228	fmt.Printf("Final response: %s\n", result.Response.Content.Text())
229
230	// Show step details
231	fmt.Println()
232	fmt.Println("🔍 Step Details")
233	fmt.Println("===============")
234	for i, step := range result.Steps {
235		fmt.Printf("Step %d:\n", i+1)
236		fmt.Printf("  Finish reason: %s\n", step.FinishReason)
237		fmt.Printf("  Content types: ")
238
239		var contentTypes []string
240		for _, content := range step.Content {
241			contentTypes = append(contentTypes, string(content.GetType()))
242		}
243		fmt.Printf("%s\n", strings.Join(contentTypes, ", "))
244
245		// Show tool calls and results
246		toolCalls := step.Content.ToolCalls()
247		if len(toolCalls) > 0 {
248			fmt.Printf("  Tool calls: ")
249			var toolNames []string
250			for _, tc := range toolCalls {
251				toolNames = append(toolNames, tc.ToolName)
252			}
253			fmt.Printf("%s\n", strings.Join(toolNames, ", "))
254		}
255
256		toolResults := step.Content.ToolResults()
257		if len(toolResults) > 0 {
258			fmt.Printf("  Tool results: %d\n", len(toolResults))
259		}
260
261		fmt.Printf("  Tokens: %d\n", step.Usage.TotalTokens)
262		fmt.Println()
263	}
264
265	fmt.Println("✨ Example completed successfully!")
266}