// main.go

  1package main
  2
  3import (
  4	"context"
  5	"fmt"
  6	"os"
  7	"strings"
  8
  9	"github.com/charmbracelet/crush/internal/ai"
 10	"github.com/charmbracelet/crush/internal/ai/providers"
 11)
 12
 13func main() {
 14	// Check for API key
 15	apiKey := os.Getenv("OPENAI_API_KEY")
 16	if apiKey == "" {
 17		fmt.Println("❌ Please set OPENAI_API_KEY environment variable")
 18		fmt.Println("   export OPENAI_API_KEY=your_api_key_here")
 19		os.Exit(1)
 20	}
 21
 22	fmt.Println("🚀 Streaming Agent Example")
 23	fmt.Println("==========================")
 24	fmt.Println()
 25
 26	// Create OpenAI provider and model
 27	provider := providers.NewOpenAIProvider(
 28		providers.WithOpenAIApiKey(apiKey),
 29	)
 30	model := provider.LanguageModel("gpt-4o-mini") // Using mini for faster/cheaper responses
 31
 32	// Define input types for type-safe tools
 33	type WeatherInput struct {
 34		Location string `json:"location" description:"The city and country, e.g. 'London, UK'"`
 35		Unit     string `json:"unit,omitempty" enum:"celsius,fahrenheit" description:"Temperature unit (celsius or fahrenheit)"`
 36	}
 37
 38	type CalculatorInput struct {
 39		Expression string `json:"expression" description:"Mathematical expression to evaluate (e.g., '2 + 2', '10 * 5')"`
 40	}
 41
 42	// Create weather tool using the new type-safe API
 43	weatherTool := ai.NewTypedToolFunc(
 44		"get_weather",
 45		"Get the current weather for a specific location",
 46		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 47			// Simulate weather lookup with some fake data
 48			location := input.Location
 49			if location == "" {
 50				location = "Unknown"
 51			}
 52
 53			// Default to celsius if not specified
 54			unit := input.Unit
 55			if unit == "" {
 56				unit = "celsius"
 57			}
 58
 59			// Simulate different temperatures for different cities
 60			var temp string
 61			if strings.Contains(strings.ToLower(location), "pristina") {
 62				temp = "15°C"
 63				if unit == "fahrenheit" {
 64					temp = "59°F"
 65				}
 66			} else if strings.Contains(strings.ToLower(location), "london") {
 67				temp = "12°C"
 68				if unit == "fahrenheit" {
 69					temp = "54°F"
 70				}
 71			} else {
 72				temp = "22°C"
 73				if unit == "fahrenheit" {
 74					temp = "72°F"
 75				}
 76			}
 77
 78			weather := fmt.Sprintf("The current weather in %s is %s with partly cloudy skies and light winds.", location, temp)
 79			return ai.NewTextResponse(weather), nil
 80		},
 81	)
 82
 83	// Create calculator tool using the new type-safe API
 84	calculatorTool := ai.NewTypedToolFunc(
 85		"calculate",
 86		"Perform basic mathematical calculations",
 87		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 88			// Simple calculator simulation
 89			expr := strings.TrimSpace(input.Expression)
 90			if strings.Contains(expr, "2 + 2") || strings.Contains(expr, "2+2") {
 91				return ai.NewTextResponse("2 + 2 = 4"), nil
 92			} else if strings.Contains(expr, "10 * 5") || strings.Contains(expr, "10*5") {
 93				return ai.NewTextResponse("10 * 5 = 50"), nil
 94			} else if strings.Contains(expr, "15 + 27") || strings.Contains(expr, "15+27") {
 95				return ai.NewTextResponse("15 + 27 = 42"), nil
 96			}
 97			return ai.NewTextResponse("I can calculate simple expressions like '2 + 2', '10 * 5', or '15 + 27'"), nil
 98		},
 99	)
100
101	// Create agent with tools
102	agent := ai.NewAgent(
103		model,
104		ai.WithSystemPrompt("You are a helpful assistant that can check weather and do calculations. Be concise and friendly."),
105		ai.WithTools(weatherTool, calculatorTool),
106	)
107
108	ctx := context.Background()
109
110	// Demonstrate streaming with comprehensive callbacks
111	fmt.Println("💬 Asking: \"What's the weather in Pristina and what's 2 + 2?\"")
112	fmt.Println()
113
114	// Track streaming events
115	var stepCount int
116	var textBuffer strings.Builder
117	var reasoningBuffer strings.Builder
118
119	// Create streaming call with all callbacks
120	streamCall := ai.AgentStreamCall{
121		Prompt: "What's the weather in Pristina and what's 2 + 2?",
122
123		// Agent-level callbacks
124		OnAgentStart: func() {
125			fmt.Println("🎬 Agent started")
126		},
127		OnAgentFinish: func(result *ai.AgentResult) {
128			fmt.Printf("🏁 Agent finished with %d steps, total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens)
129		},
130		OnStepStart: func(stepNumber int) {
131			stepCount++
132			fmt.Printf("📝 Step %d started\n", stepNumber+1)
133		},
134		OnStepFinish: func(stepResult ai.StepResult) {
135			fmt.Printf("✅ Step completed (reason: %s, tokens: %d)\n", stepResult.FinishReason, stepResult.Usage.TotalTokens)
136		},
137		OnFinish: func(result *ai.AgentResult) {
138			fmt.Printf("🎯 Final result ready with %d steps\n", len(result.Steps))
139		},
140		OnError: func(err error) {
141			fmt.Printf("❌ Error: %v\n", err)
142		},
143
144		// Stream part callbacks
145		OnWarnings: func(warnings []ai.CallWarning) {
146			for _, warning := range warnings {
147				fmt.Printf("⚠️  Warning: %s\n", warning.Message)
148			}
149		},
150		OnTextStart: func(id string) {
151			fmt.Print("💭 Assistant: ")
152		},
153		OnTextDelta: func(id, text string) {
154			fmt.Print(text)
155			textBuffer.WriteString(text)
156		},
157		OnTextEnd: func(id string) {
158			fmt.Println()
159		},
160		OnReasoningStart: func(id string) {
161			fmt.Print("🤔 Thinking: ")
162		},
163		OnReasoningDelta: func(id, text string) {
164			reasoningBuffer.WriteString(text)
165		},
166		OnReasoningEnd: func(id string) {
167			if reasoningBuffer.Len() > 0 {
168				fmt.Printf("%s\n", reasoningBuffer.String())
169				reasoningBuffer.Reset()
170			}
171		},
172		OnToolInputStart: func(id, toolName string) {
173			fmt.Printf("🔧 Calling tool: %s\n", toolName)
174		},
175		OnToolInputDelta: func(id, delta string) {
176			// Could show tool input being built, but it's often noisy
177		},
178		OnToolInputEnd: func(id string) {
179			// Tool input complete
180		},
181		OnToolCall: func(toolCall ai.ToolCallContent) {
182			fmt.Printf("🛠️  Tool call: %s\n", toolCall.ToolName)
183			fmt.Printf("   Input: %s\n", toolCall.Input)
184		},
185		OnToolResult: func(result ai.ToolResultContent) {
186			fmt.Printf("🎯 Tool result from %s:\n", result.ToolName)
187			switch output := result.Result.(type) {
188			case ai.ToolResultOutputContentText:
189				fmt.Printf("   %s\n", output.Text)
190			case ai.ToolResultOutputContentError:
191				fmt.Printf("   Error: %s\n", output.Error.Error())
192			}
193		},
194		OnSource: func(source ai.SourceContent) {
195			fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL)
196		},
197		OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderOptions) {
198			fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens)
199		},
200		OnStreamError: func(err error) {
201			fmt.Printf("💥 Stream error: %v\n", err)
202		},
203	}
204
205	// Execute streaming agent
206	result, err := agent.Stream(ctx, streamCall)
207	if err != nil {
208		fmt.Printf("❌ Agent failed: %v\n", err)
209		os.Exit(1)
210	}
211
212	// Display final results
213	fmt.Println()
214	fmt.Println("📋 Final Summary")
215	fmt.Println("================")
216	fmt.Printf("Steps executed: %d\n", len(result.Steps))
217	fmt.Printf("Total tokens used: %d (input: %d, output: %d)\n",
218		result.TotalUsage.TotalTokens,
219		result.TotalUsage.InputTokens,
220		result.TotalUsage.OutputTokens)
221
222	if result.TotalUsage.ReasoningTokens > 0 {
223		fmt.Printf("Reasoning tokens: %d\n", result.TotalUsage.ReasoningTokens)
224	}
225
226	fmt.Printf("Final response: %s\n", result.Response.Content.Text())
227
228	// Show step details
229	fmt.Println()
230	fmt.Println("🔍 Step Details")
231	fmt.Println("===============")
232	for i, step := range result.Steps {
233		fmt.Printf("Step %d:\n", i+1)
234		fmt.Printf("  Finish reason: %s\n", step.FinishReason)
235		fmt.Printf("  Content types: ")
236
237		var contentTypes []string
238		for _, content := range step.Content {
239			contentTypes = append(contentTypes, string(content.GetType()))
240		}
241		fmt.Printf("%s\n", strings.Join(contentTypes, ", "))
242
243		// Show tool calls and results
244		toolCalls := step.Content.ToolCalls()
245		if len(toolCalls) > 0 {
246			fmt.Printf("  Tool calls: ")
247			var toolNames []string
248			for _, tc := range toolCalls {
249				toolNames = append(toolNames, tc.ToolName)
250			}
251			fmt.Printf("%s\n", strings.Join(toolNames, ", "))
252		}
253
254		toolResults := step.Content.ToolResults()
255		if len(toolResults) > 0 {
256			fmt.Printf("  Tool results: %d\n", len(toolResults))
257		}
258
259		fmt.Printf("  Tokens: %d\n", step.Usage.TotalTokens)
260		fmt.Println()
261	}
262
263	fmt.Println("✨ Example completed successfully!")
264}