main.go

  1package main
  2
  3import (
  4	"context"
  5	"fmt"
  6	"os"
  7	"strings"
  8
  9	"github.com/charmbracelet/crush/internal/ai"
 10	"github.com/charmbracelet/crush/internal/ai/providers"
 11	"github.com/charmbracelet/crush/internal/llm/tools"
 12)
 13
 14// WeatherTool is a simple tool that simulates weather lookup
 15type WeatherTool struct{}
 16
 17func (w *WeatherTool) Info() tools.ToolInfo {
 18	return tools.ToolInfo{
 19		Name:        "get_weather",
 20		Description: "Get the current weather for a specific location",
 21		Parameters: map[string]any{
 22			"location": map[string]any{
 23				"type":        "string",
 24				"description": "The city and country, e.g. 'London, UK'",
 25			},
 26			"unit": map[string]any{
 27				"type":        "string",
 28				"description": "Temperature unit (celsius or fahrenheit)",
 29				"enum":        []string{"celsius", "fahrenheit"},
 30				"default":     "celsius",
 31			},
 32		},
 33		Required: []string{"location"},
 34	}
 35}
 36
 37func (w *WeatherTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 38	// Simulate weather lookup with some fake data
 39	location := "Unknown"
 40	if strings.Contains(params.Input, "pristina") || strings.Contains(params.Input, "Pristina") {
 41		location = "Pristina, Kosovo"
 42	} else if strings.Contains(params.Input, "london") || strings.Contains(params.Input, "London") {
 43		location = "London, UK"
 44	} else if strings.Contains(params.Input, "new york") || strings.Contains(params.Input, "New York") {
 45		location = "New York, USA"
 46	}
 47
 48	unit := "celsius"
 49	if strings.Contains(params.Input, "fahrenheit") {
 50		unit = "fahrenheit"
 51	}
 52
 53	var temp string
 54	if unit == "fahrenheit" {
 55		temp = "72°F"
 56	} else {
 57		temp = "22°C"
 58	}
 59
 60	weather := fmt.Sprintf("The current weather in %s is %s with partly cloudy skies and light winds.", location, temp)
 61	return tools.NewTextResponse(weather), nil
 62}
 63
 64// CalculatorTool demonstrates a second tool for multi-tool scenarios
 65type CalculatorTool struct{}
 66
 67func (c *CalculatorTool) Info() tools.ToolInfo {
 68	return tools.ToolInfo{
 69		Name:        "calculate",
 70		Description: "Perform basic mathematical calculations",
 71		Parameters: map[string]any{
 72			"expression": map[string]any{
 73				"type":        "string",
 74				"description": "Mathematical expression to evaluate (e.g., '2 + 2', '10 * 5')",
 75			},
 76		},
 77		Required: []string{"expression"},
 78	}
 79}
 80
 81func (c *CalculatorTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 82	// Simple calculator simulation
 83	expr := strings.ReplaceAll(params.Input, "\"", "")
 84	if strings.Contains(expr, "2 + 2") || strings.Contains(expr, "2+2") {
 85		return tools.NewTextResponse("2 + 2 = 4"), nil
 86	} else if strings.Contains(expr, "10 * 5") || strings.Contains(expr, "10*5") {
 87		return tools.NewTextResponse("10 * 5 = 50"), nil
 88	}
 89	return tools.NewTextResponse("I can calculate simple expressions like '2 + 2' or '10 * 5'"), nil
 90}
 91
 92func main() {
 93	// Check for API key
 94	apiKey := os.Getenv("OPENAI_API_KEY")
 95	if apiKey == "" {
 96		fmt.Println("❌ Please set OPENAI_API_KEY environment variable")
 97		fmt.Println("   export OPENAI_API_KEY=your_api_key_here")
 98		os.Exit(1)
 99	}
100
101	fmt.Println("🚀 Streaming Agent Example")
102	fmt.Println("==========================")
103	fmt.Println()
104
105	// Create OpenAI provider and model
106	provider := providers.NewOpenAIProvider(
107		providers.WithOpenAIApiKey(apiKey),
108	)
109	model := provider.LanguageModel("gpt-4o-mini") // Using mini for faster/cheaper responses
110
111	// Create agent with tools
112	agent := ai.NewAgent(
113		model,
114		ai.WithSystemPrompt("You are a helpful assistant that can check weather and do calculations. Be concise and friendly."),
115		ai.WithTools(&WeatherTool{}, &CalculatorTool{}),
116	)
117
118	ctx := context.Background()
119
120	// Demonstrate streaming with comprehensive callbacks
121	fmt.Println("💬 Asking: \"What's the weather in Pristina and what's 2 + 2?\"")
122	fmt.Println()
123
124	// Track streaming events
125	var stepCount int
126	var textBuffer strings.Builder
127	var reasoningBuffer strings.Builder
128
129	// Create streaming call with all callbacks
130	streamCall := ai.AgentStreamCall{
131		Prompt: "What's the weather in Pristina and what's 2 + 2?",
132
133		// Agent-level callbacks
134		OnAgentStart: func() {
135			fmt.Println("🎬 Agent started")
136		},
137		OnAgentFinish: func(result *ai.AgentResult) {
138			fmt.Printf("🏁 Agent finished with %d steps, total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens)
139		},
140		OnStepStart: func(stepNumber int) {
141			stepCount++
142			fmt.Printf("📝 Step %d started\n", stepNumber+1)
143		},
144		OnStepFinish: func(stepResult ai.StepResult) {
145			fmt.Printf("✅ Step completed (reason: %s, tokens: %d)\n", stepResult.FinishReason, stepResult.Usage.TotalTokens)
146		},
147		OnFinish: func(result *ai.AgentResult) {
148			fmt.Printf("🎯 Final result ready with %d steps\n", len(result.Steps))
149		},
150		OnError: func(err error) {
151			fmt.Printf("❌ Error: %v\n", err)
152		},
153
154		// Stream part callbacks
155		OnWarnings: func(warnings []ai.CallWarning) {
156			for _, warning := range warnings {
157				fmt.Printf("⚠️  Warning: %s\n", warning.Message)
158			}
159		},
160		OnTextStart: func(id string) {
161			fmt.Print("💭 Assistant: ")
162		},
163		OnTextDelta: func(id, text string) {
164			fmt.Print(text)
165			textBuffer.WriteString(text)
166		},
167		OnTextEnd: func(id string) {
168			fmt.Println()
169		},
170		OnReasoningStart: func(id string) {
171			fmt.Print("🤔 Thinking: ")
172		},
173		OnReasoningDelta: func(id, text string) {
174			reasoningBuffer.WriteString(text)
175		},
176		OnReasoningEnd: func(id string) {
177			if reasoningBuffer.Len() > 0 {
178				fmt.Printf("%s\n", reasoningBuffer.String())
179				reasoningBuffer.Reset()
180			}
181		},
182		OnToolInputStart: func(id, toolName string) {
183			fmt.Printf("🔧 Calling tool: %s\n", toolName)
184		},
185		OnToolInputDelta: func(id, delta string) {
186			// Could show tool input being built, but it's often noisy
187		},
188		OnToolInputEnd: func(id string) {
189			// Tool input complete
190		},
191		OnToolCall: func(toolCall ai.ToolCallContent) {
192			fmt.Printf("🛠️  Tool call: %s\n", toolCall.ToolName)
193			fmt.Printf("   Input: %s\n", toolCall.Input)
194		},
195		OnToolResult: func(result ai.ToolResultContent) {
196			fmt.Printf("🎯 Tool result from %s:\n", result.ToolName)
197			switch output := result.Result.(type) {
198			case ai.ToolResultOutputContentText:
199				fmt.Printf("   %s\n", output.Text)
200			case ai.ToolResultOutputContentError:
201				fmt.Printf("   Error: %s\n", output.Error.Error())
202			}
203		},
204		OnSource: func(source ai.SourceContent) {
205			fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL)
206		},
207		OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderOptions) {
208			fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens)
209		},
210		OnStreamError: func(err error) {
211			fmt.Printf("💥 Stream error: %v\n", err)
212		},
213	}
214
215	// Execute streaming agent
216	result, err := agent.Stream(ctx, streamCall)
217	if err != nil {
218		fmt.Printf("❌ Agent failed: %v\n", err)
219		os.Exit(1)
220	}
221
222	// Display final results
223	fmt.Println()
224	fmt.Println("📋 Final Summary")
225	fmt.Println("================")
226	fmt.Printf("Steps executed: %d\n", len(result.Steps))
227	fmt.Printf("Total tokens used: %d (input: %d, output: %d)\n", 
228		result.TotalUsage.TotalTokens, 
229		result.TotalUsage.InputTokens, 
230		result.TotalUsage.OutputTokens)
231	
232	if result.TotalUsage.ReasoningTokens > 0 {
233		fmt.Printf("Reasoning tokens: %d\n", result.TotalUsage.ReasoningTokens)
234	}
235
236	fmt.Printf("Final response: %s\n", result.Response.Content.Text())
237
238	// Show step details
239	fmt.Println()
240	fmt.Println("🔍 Step Details")
241	fmt.Println("===============")
242	for i, step := range result.Steps {
243		fmt.Printf("Step %d:\n", i+1)
244		fmt.Printf("  Finish reason: %s\n", step.FinishReason)
245		fmt.Printf("  Content types: ")
246		
247		var contentTypes []string
248		for _, content := range step.Content {
249			contentTypes = append(contentTypes, string(content.GetType()))
250		}
251		fmt.Printf("%s\n", strings.Join(contentTypes, ", "))
252		
253		// Show tool calls and results
254		toolCalls := step.Content.ToolCalls()
255		if len(toolCalls) > 0 {
256			fmt.Printf("  Tool calls: ")
257			var toolNames []string
258			for _, tc := range toolCalls {
259				toolNames = append(toolNames, tc.ToolName)
260			}
261			fmt.Printf("%s\n", strings.Join(toolNames, ", "))
262		}
263		
264		toolResults := step.Content.ToolResults()
265		if len(toolResults) > 0 {
266			fmt.Printf("  Tool results: %d\n", len(toolResults))
267		}
268		
269		fmt.Printf("  Tokens: %d\n", step.Usage.TotalTokens)
270		fmt.Println()
271	}
272
273	fmt.Println("✨ Example completed successfully!")
274}