Detailed changes
@@ -104,12 +104,12 @@ type ToolCallRepairOptions struct {
}
type (
- PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult
+ PrepareStepFunction = func(options PrepareStepFunctionOptions) (PrepareStepResult, error)
OnStepFinishedFunction = func(step StepResult)
RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
-type AgentSettings struct {
+type agentSettings struct {
systemPrompt string
maxOutputTokens *int64
temperature *float64
@@ -129,7 +129,6 @@ type AgentSettings struct {
stopWhen []StopCondition
prepareStep PrepareStepFunction
repairToolCall RepairToolCallFunction
- onStepFinished OnStepFinishedFunction
onRetry OnRetryCallback
}
@@ -152,7 +151,6 @@ type AgentCall struct {
StopWhen []StopCondition
PrepareStep PrepareStepFunction
RepairToolCall RepairToolCallFunction
- OnStepFinished OnStepFinishedFunction
}
type AgentStreamCall struct {
@@ -184,22 +182,22 @@ type AgentStreamCall struct {
OnError func(error) // Called when an error occurs
// Stream part callbacks - called for each corresponding stream part type
- OnChunk func(StreamPart) // Called for each stream part (catch-all)
- OnWarnings func(warnings []CallWarning) // Called for warnings
- OnTextStart func(id string) // Called when text starts
- OnTextDelta func(id, text string) // Called for text deltas
- OnTextEnd func(id string) // Called when text ends
- OnReasoningStart func(id string) // Called when reasoning starts
- OnReasoningDelta func(id, text string) // Called for reasoning deltas
- OnReasoningEnd func(id string) // Called when reasoning ends
- OnToolInputStart func(id, toolName string) // Called when tool input starts
- OnToolInputDelta func(id, delta string) // Called for tool input deltas
- OnToolInputEnd func(id string) // Called when tool input ends
- OnToolCall func(toolCall ToolCallContent) // Called when tool call is complete
- OnToolResult func(result ToolResultContent) // Called when tool execution completes
- OnSource func(source SourceContent) // Called for source references
- OnStreamFinish func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) // Called when stream finishes
- OnStreamError func(error) // Called when stream error occurs
+ OnChunk func(StreamPart) // Called for each stream part (catch-all)
+ OnWarnings func(warnings []CallWarning) // Called for warnings
+ OnTextStart func(id string) // Called when text starts
+ OnTextDelta func(id, text string) // Called for text deltas
+ OnTextEnd func(id string) // Called when text ends
+ OnReasoningStart func(id string) // Called when reasoning starts
+ OnReasoningDelta func(id, text string) // Called for reasoning deltas
+ OnReasoningEnd func(id string, reasoning ReasoningContent) // Called when reasoning ends
+ OnToolInputStart func(id, toolName string) // Called when tool input starts
+ OnToolInputDelta func(id, delta string) // Called for tool input deltas
+ OnToolInputEnd func(id string) // Called when tool input ends
+ OnToolCall func(toolCall ToolCallContent) // Called when tool call is complete
+ OnToolResult func(result ToolResultContent) // Called when tool execution completes
+ OnSource func(source SourceContent) // Called for source references
+ OnStreamFinish func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) // Called when stream finishes
+ OnStreamError func(error) // Called when stream error occurs
}
type AgentResult struct {
@@ -214,14 +212,14 @@ type Agent interface {
Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
-type agentOption = func(*AgentSettings)
+type AgentOption = func(*agentSettings)
type agent struct {
- settings AgentSettings
+ settings agentSettings
}
-func NewAgent(model LanguageModel, opts ...agentOption) Agent {
- settings := AgentSettings{
+func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
+ settings := agentSettings{
model: model,
}
for _, o := range opts {
@@ -260,9 +258,6 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
call.RepairToolCall = a.settings.repairToolCall
}
- if call.OnStepFinished == nil && a.settings.onStepFinished != nil {
- call.OnStepFinished = a.settings.onStepFinished
- }
if call.OnRetry == nil && a.settings.onRetry != nil {
call.OnRetry = a.settings.onRetry
}
@@ -311,12 +306,15 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
disableAllTools := false
if opts.PrepareStep != nil {
- prepared := opts.PrepareStep(PrepareStepFunctionOptions{
+ prepared, err := opts.PrepareStep(PrepareStepFunctionOptions{
Model: stepModel,
Steps: steps,
StepNumber: len(steps),
Messages: stepInputMessages,
})
+ if err != nil {
+ return nil, err
+ }
// Apply prepared step modifications
if prepared.Messages != nil {
@@ -423,10 +421,6 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
Messages: currentStepMessages,
}
steps = append(steps, stepResult)
- if opts.OnStepFinished != nil {
- opts.OnStepFinished(stepResult)
- }
-
shouldStop := isStopConditionMet(opts.StopWhen, steps)
if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
@@ -709,12 +703,15 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
// Apply step preparation if provided
if call.PrepareStep != nil {
- prepared := call.PrepareStep(PrepareStepFunctionOptions{
+ prepared, err := call.PrepareStep(PrepareStepFunctionOptions{
Model: stepModel,
Steps: steps,
StepNumber: stepNumber,
Messages: stepInputMessages,
})
+ if err != nil {
+ return nil, err
+ }
if prepared.Messages != nil {
stepInputMessages = prepared.Messages
@@ -925,78 +922,72 @@ func (a *agent) createPrompt(system, prompt string, messages []Message, files ..
return preparedPrompt, nil
}
-func WithSystemPrompt(prompt string) agentOption {
- return func(s *AgentSettings) {
+func WithSystemPrompt(prompt string) AgentOption {
+ return func(s *agentSettings) {
s.systemPrompt = prompt
}
}
-func WithMaxOutputTokens(tokens int64) agentOption {
- return func(s *AgentSettings) {
+func WithMaxOutputTokens(tokens int64) AgentOption {
+ return func(s *agentSettings) {
s.maxOutputTokens = &tokens
}
}
-func WithTemperature(temp float64) agentOption {
- return func(s *AgentSettings) {
+func WithTemperature(temp float64) AgentOption {
+ return func(s *agentSettings) {
s.temperature = &temp
}
}
-func WithTopP(topP float64) agentOption {
- return func(s *AgentSettings) {
+func WithTopP(topP float64) AgentOption {
+ return func(s *agentSettings) {
s.topP = &topP
}
}
-func WithTopK(topK int64) agentOption {
- return func(s *AgentSettings) {
+func WithTopK(topK int64) AgentOption {
+ return func(s *agentSettings) {
s.topK = &topK
}
}
-func WithPresencePenalty(penalty float64) agentOption {
- return func(s *AgentSettings) {
+func WithPresencePenalty(penalty float64) AgentOption {
+ return func(s *agentSettings) {
s.presencePenalty = &penalty
}
}
-func WithFrequencyPenalty(penalty float64) agentOption {
- return func(s *AgentSettings) {
+func WithFrequencyPenalty(penalty float64) AgentOption {
+ return func(s *agentSettings) {
s.frequencyPenalty = &penalty
}
}
-func WithTools(tools ...AgentTool) agentOption {
- return func(s *AgentSettings) {
+func WithTools(tools ...AgentTool) AgentOption {
+ return func(s *agentSettings) {
s.tools = append(s.tools, tools...)
}
}
-func WithStopConditions(conditions ...StopCondition) agentOption {
- return func(s *AgentSettings) {
+func WithStopConditions(conditions ...StopCondition) AgentOption {
+ return func(s *agentSettings) {
s.stopWhen = append(s.stopWhen, conditions...)
}
}
-func WithPrepareStep(fn PrepareStepFunction) agentOption {
- return func(s *AgentSettings) {
+func WithPrepareStep(fn PrepareStepFunction) AgentOption {
+ return func(s *agentSettings) {
s.prepareStep = fn
}
}
-func WithRepairToolCall(fn RepairToolCallFunction) agentOption {
- return func(s *AgentSettings) {
+func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
+ return func(s *agentSettings) {
s.repairToolCall = fn
}
}
-func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
- return func(s *AgentSettings) {
- s.onStepFinished = fn
- }
-}
-
// processStepStream processes a single step's stream and returns the step result
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
var stepContent []Content
@@ -1069,11 +1060,14 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
Text: text,
ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
})
+ if opts.OnReasoningEnd != nil {
+ opts.OnReasoningEnd(part.ID, ReasoningContent{
+ Text: text,
+ ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
+ })
+ }
delete(activeTextContent, part.ID)
}
- if opts.OnReasoningEnd != nil {
- opts.OnReasoningEnd(part.ID)
- }
case StreamPartTypeToolInputStart:
activeToolCalls[part.ID] = &ToolCallContent{
@@ -1194,14 +1188,14 @@ func addUsage(a, b Usage) Usage {
}
}
-func WithHeaders(headers map[string]string) agentOption {
- return func(s *AgentSettings) {
+func WithHeaders(headers map[string]string) AgentOption {
+ return func(s *agentSettings) {
s.headers = headers
}
}
-func WithProviderOptions(providerOptions ProviderOptions) agentOption {
- return func(s *AgentSettings) {
+func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
+ return func(s *agentSettings) {
s.providerOptions = providerOptions
}
}
@@ -145,7 +145,7 @@ func TestStreamingAgentCallbacks(t *testing.T) {
OnReasoningDelta: func(id, text string) {
callbacks["OnReasoningDelta"] = true
},
- OnReasoningEnd: func(id string) {
+ OnReasoningEnd: func(id string, content ReasoningContent) {
callbacks["OnReasoningEnd"] = true
},
OnToolInputStart: func(id, toolName string) {
@@ -166,7 +166,7 @@ func TestStreamingAgentCallbacks(t *testing.T) {
OnSource: func(source SourceContent) {
callbacks["OnSource"] = true
},
- OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) {
+ OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) {
callbacks["OnStreamFinish"] = true
},
OnStreamError: func(err error) {
@@ -962,13 +962,13 @@ func TestPrepareStep(t *testing.T) {
},
}
- prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
return PrepareStepResult{
Model: options.Model,
Messages: options.Messages,
System: &newSystem,
- }
+ }, nil
}
agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
@@ -999,13 +999,13 @@ func TestPrepareStep(t *testing.T) {
},
}
- prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
toolChoice := ToolChoiceNone
return PrepareStepResult{
Model: options.Model,
Messages: options.Messages,
ToolChoice: &toolChoice,
- }
+ }, nil
}
agent := NewAgent(model)
@@ -1044,13 +1044,13 @@ func TestPrepareStep(t *testing.T) {
tool2 := &mockTool{name: "tool2", description: "Tool 2"}
tool3 := &mockTool{name: "tool3", description: "Tool 3"}
- prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
activeTools := []string{"tool2"} // Only tool2 should be active
return PrepareStepResult{
Model: options.Model,
Messages: options.Messages,
ActiveTools: activeTools,
- }
+ }, nil
}
agent := NewAgent(model, WithTools(tool1, tool2, tool3))
@@ -1084,12 +1084,12 @@ func TestPrepareStep(t *testing.T) {
tool1 := &mockTool{name: "tool1", description: "Tool 1"}
- prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
return PrepareStepResult{
Model: options.Model,
Messages: options.Messages,
DisableAllTools: true, // Disable all tools for this step
- }
+ }, nil
}
agent := NewAgent(model, WithTools(tool1))
@@ -1139,7 +1139,7 @@ func TestPrepareStep(t *testing.T) {
tool1 := &mockTool{name: "tool1", description: "Tool 1"}
tool2 := &mockTool{name: "tool2", description: "Tool 2"}
- prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
newSystem := "Step-specific system"
toolChoice := SpecificToolChoice("tool1")
activeTools := []string{"tool1"}
@@ -1149,7 +1149,7 @@ func TestPrepareStep(t *testing.T) {
System: &newSystem,
ToolChoice: &toolChoice,
ActiveTools: activeTools,
- }
+ }, nil
}
agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
@@ -1202,7 +1202,7 @@ func TestPrepareStep(t *testing.T) {
tool1 := &mockTool{name: "tool1", description: "Tool 1"}
- prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
// All optional fields are nil, should use parent values
return PrepareStepResult{
Model: options.Model,
@@ -1210,7 +1210,7 @@ func TestPrepareStep(t *testing.T) {
System: nil, // Use parent
ToolChoice: nil, // Use parent (auto)
ActiveTools: nil, // Use parent (all tools)
- }
+ }, nil
}
agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
@@ -1251,12 +1251,12 @@ func TestPrepareStep(t *testing.T) {
tool1 := &mockTool{name: "tool1", description: "Tool 1"}
tool2 := &mockTool{name: "tool2", description: "Tool 2"}
- prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
return PrepareStepResult{
Model: options.Model,
Messages: options.Messages,
ActiveTools: []string{}, // Empty slice means all tools
- }
+ }, nil
}
agent := NewAgent(model, WithTools(tool1, tool2))
@@ -78,7 +78,22 @@ type Message struct {
ProviderOptions ProviderOptions `json:"provider_options"`
}
-func AsContentType[T MessagePart](content MessagePart) (T, bool) {
+func AsContentType[T Content](content Content) (T, bool) {
+ var zero T
+ if content == nil {
+ return zero, false
+ }
+ switch v := any(content).(type) {
+ case T:
+ return v, true
+ case *T:
+ return *v, true
+ default:
+ return zero, false
+ }
+}
+
+func AsMessagePart[T MessagePart](content MessagePart) (T, bool) {
var zero T
if content == nil {
return zero, false
@@ -96,6 +111,7 @@ func AsContentType[T MessagePart](content MessagePart) (T, bool) {
// MessagePart represents a part of a message content.
type MessagePart interface {
GetType() ContentType
+ Options() ProviderOptions
}
// TextPart represents text content in a message.
@@ -109,6 +125,10 @@ func (t TextPart) GetType() ContentType {
return ContentTypeText
}
+func (t TextPart) Options() ProviderOptions {
+ return t.ProviderOptions
+}
+
// ReasoningPart represents reasoning content in a message.
type ReasoningPart struct {
Text string `json:"text"`
@@ -120,6 +140,10 @@ func (r ReasoningPart) GetType() ContentType {
return ContentTypeReasoning
}
+func (r ReasoningPart) Options() ProviderOptions {
+ return r.ProviderOptions
+}
+
// FilePart represents file content in a message.
type FilePart struct {
Filename string `json:"filename"`
@@ -133,6 +157,10 @@ func (f FilePart) GetType() ContentType {
return ContentTypeFile
}
+func (f FilePart) Options() ProviderOptions {
+ return f.ProviderOptions
+}
+
// ToolCallPart represents a tool call in a message.
type ToolCallPart struct {
ToolCallID string `json:"tool_call_id"`
@@ -147,6 +175,10 @@ func (t ToolCallPart) GetType() ContentType {
return ContentTypeToolCall
}
+func (t ToolCallPart) Options() ProviderOptions {
+ return t.ProviderOptions
+}
+
// ToolResultPart represents a tool result in a message.
type ToolResultPart struct {
ToolCallID string `json:"tool_call_id"`
@@ -159,6 +191,10 @@ func (t ToolResultPart) GetType() ContentType {
return ContentTypeToolResult
}
+func (t ToolResultPart) Options() ProviderOptions {
+ return t.ProviderOptions
+}
+
// ToolResultContentType represents the type of tool result output.
type ToolResultContentType string
@@ -163,7 +163,7 @@ func main() {
OnReasoningDelta: func(id, text string) {
reasoningBuffer.WriteString(text)
},
- OnReasoningEnd: func(id string) {
+ OnReasoningEnd: func(id string, content ai.ReasoningContent) {
if reasoningBuffer.Len() > 0 {
fmt.Printf("%s\n", reasoningBuffer.String())
reasoningBuffer.Reset()
@@ -159,15 +159,16 @@ type StreamPart struct {
URL string `json:"url"`
Title string `json:"title"`
- ProviderMetadata ProviderOptions `json:"provider_metadata"`
+ ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
type StreamResponse = iter.Seq[StreamPart]
type ToolChoice string
const (
- ToolChoiceNone ToolChoice = "none"
- ToolChoiceAuto ToolChoice = "auto"
+ ToolChoiceNone ToolChoice = "none"
+ ToolChoiceAuto ToolChoice = "auto"
+ ToolChoiceRequired ToolChoice = "required"
)
func SpecificToolChoice(name string) ToolChoice {
@@ -110,6 +110,9 @@ type funcToolWrapper[TInput any] struct {
}
func (w *funcToolWrapper[TInput]) Info() ToolInfo {
+ if w.schema.Required == nil {
+ w.schema.Required = []string{}
+ }
return ToolInfo{
Name: w.name,
Description: w.description,