Detailed changes
@@ -0,0 +1,29 @@
+# CRUSH.md - Fantasy AI SDK
+
+## Build/Test/Lint Commands
+- **Build**: `go build ./...`
+- **Test all**: `task test` or `go test ./... -count=1`
+- **Test single**: `go test -run TestName ./package -v`
+- **Test with args**: `task test -- -v -run TestName`
+- **Lint**: `task lint` or `golangci-lint run`
+- **Format**: `task fmt` or `gofmt -s -w .`
+- **Modernize**: `task modernize` or `modernize -fix ./...`
+
+## Code Style Guidelines
+- **Package naming**: lowercase, single word (ai, openai, anthropic, google)
+- **Imports**: standard library first, then third-party, then local packages
+- **Error handling**: Use custom error types with structured fields, wrap with context
+- **Types**: Use type aliases for function signatures (`type Option = func(*options)`)
+- **Naming**: CamelCase for exported, camelCase for unexported
+- **Constants**: Use const blocks with descriptive names (ProviderName, DefaultURL)
+- **Structs**: Embed anonymous structs for composition (APICallError embeds *AIError)
+- **Functions**: Return error as last parameter, use context.Context as first param
+- **Testing**: Use testify/assert, table-driven tests, recorder pattern for HTTP mocking
+- **Comments**: Godoc format for exported functions, explain complex logic inline
+- **JSON**: Use struct tags for marshaling, handle empty values gracefully
+
+## Project Structure
+- `/ai` - Core AI abstractions and agent logic
+- `/openai`, `/anthropic`, `/google` - Provider implementations
+- `/providertests` - Cross-provider integration tests with VCR recordings
+- `/examples` - Usage examples for different patterns
@@ -4,8 +4,9 @@ import (
"context"
"fmt"
"reflect"
- "strings"
"testing"
+
+ "github.com/stretchr/testify/require"
)
// Example of a simple typed tool using the function approach
@@ -28,12 +29,9 @@ func TestTypedToolFuncExample(t *testing.T) {
// Check the tool info
info := tool.Info()
- if info.Name != "calculator" {
- t.Errorf("Expected tool name 'calculator', got %s", info.Name)
- }
- if len(info.Required) != 1 || info.Required[0] != "expression" {
- t.Errorf("Expected required field 'expression', got %v", info.Required)
- }
+ require.Equal(t, "calculator", info.Name)
+ require.Len(t, info.Required, 1)
+ require.Equal(t, "expression", info.Required[0])
// Test execution
call := ToolCall{
@@ -43,15 +41,9 @@ func TestTypedToolFuncExample(t *testing.T) {
}
result, err := tool.Run(context.Background(), call)
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
- if result.Content != "4" {
- t.Errorf("Expected result '4', got %s", result.Content)
- }
- if result.IsError {
- t.Errorf("Expected successful result, got error")
- }
+ require.NoError(t, err)
+ require.Equal(t, "4", result.Content)
+ require.False(t, result.IsError)
}
func TestEnumToolExample(t *testing.T) {
@@ -76,13 +68,10 @@ func TestEnumToolExample(t *testing.T) {
// Check that the schema includes enum values
info := tool.Info()
unitsParam, ok := info.Parameters["units"].(map[string]any)
- if !ok {
- t.Fatal("Expected units parameter to exist")
- }
+ require.True(t, ok, "Expected units parameter to exist")
enumValues, ok := unitsParam["enum"].([]any)
- if !ok || len(enumValues) != 2 {
- t.Errorf("Expected 2 enum values, got %v", enumValues)
- }
+ require.True(t, ok)
+ require.Len(t, enumValues, 2)
// Test execution with enum value
call := ToolCall{
@@ -92,15 +81,9 @@ func TestEnumToolExample(t *testing.T) {
}
result, err := tool.Run(context.Background(), call)
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
- if !strings.Contains(result.Content, "San Francisco") {
- t.Errorf("Expected result to contain 'San Francisco', got %s", result.Content)
- }
- if !strings.Contains(result.Content, "72ยฐF") {
- t.Errorf("Expected result to contain '72ยฐF', got %s", result.Content)
- }
+ require.NoError(t, err)
+ require.Contains(t, result.Content, "San Francisco")
+ require.Contains(t, result.Content, "72ยฐF")
}
func TestEnumSupport(t *testing.T) {
@@ -113,30 +96,20 @@ func TestEnumSupport(t *testing.T) {
schema := generateSchema(reflect.TypeOf(WeatherInput{}))
- if schema.Type != "object" {
- t.Errorf("Expected schema type 'object', got %s", schema.Type)
- }
+ require.Equal(t, "object", schema.Type)
// Check units field has enum values
unitsSchema := schema.Properties["units"]
- if unitsSchema == nil {
- t.Fatal("Expected units property to exist")
- }
- if len(unitsSchema.Enum) != 3 {
- t.Errorf("Expected 3 enum values for units, got %d", len(unitsSchema.Enum))
- }
+ require.NotNil(t, unitsSchema, "Expected units property to exist")
+ require.Len(t, unitsSchema.Enum, 3)
expectedUnits := []string{"celsius", "fahrenheit", "kelvin"}
for i, expected := range expectedUnits {
- if unitsSchema.Enum[i] != expected {
- t.Errorf("Expected enum value %s, got %v", expected, unitsSchema.Enum[i])
- }
+ require.Equal(t, expected, unitsSchema.Enum[i])
}
// Check required fields (format should not be required due to omitempty)
expectedRequired := []string{"location", "units"}
- if len(schema.Required) != len(expectedRequired) {
- t.Errorf("Expected %d required fields, got %d", len(expectedRequired), len(schema.Required))
- }
+ require.Len(t, schema.Required, len(expectedRequired))
}
func TestSchemaToParameters(t *testing.T) {
@@ -170,43 +143,24 @@ func TestSchemaToParameters(t *testing.T) {
// Check name parameter
nameParam, ok := params["name"].(map[string]any)
- if !ok {
- t.Fatal("Expected name parameter to exist")
- }
- if nameParam["type"] != "string" {
- t.Errorf("Expected name type 'string', got %v", nameParam["type"])
- }
- if nameParam["description"] != "The name field" {
- t.Errorf("Expected name description 'The name field', got %v", nameParam["description"])
- }
+ require.True(t, ok, "Expected name parameter to exist")
+ require.Equal(t, "string", nameParam["type"])
+ require.Equal(t, "The name field", nameParam["description"])
// Check age parameter with min/max
ageParam, ok := params["age"].(map[string]any)
- if !ok {
- t.Fatal("Expected age parameter to exist")
- }
- if ageParam["type"] != "integer" {
- t.Errorf("Expected age type 'integer', got %v", ageParam["type"])
- }
- if ageParam["minimum"] != 0.0 {
- t.Errorf("Expected age minimum 0, got %v", ageParam["minimum"])
- }
- if ageParam["maximum"] != 120.0 {
- t.Errorf("Expected age maximum 120, got %v", ageParam["maximum"])
- }
+ require.True(t, ok, "Expected age parameter to exist")
+ require.Equal(t, "integer", ageParam["type"])
+ require.Equal(t, 0.0, ageParam["minimum"])
+ require.Equal(t, 120.0, ageParam["maximum"])
// Check priority parameter with enum
priorityParam, ok := params["priority"].(map[string]any)
- if !ok {
- t.Fatal("Expected priority parameter to exist")
- }
- if priorityParam["type"] != "string" {
- t.Errorf("Expected priority type 'string', got %v", priorityParam["type"])
- }
+ require.True(t, ok, "Expected priority parameter to exist")
+ require.Equal(t, "string", priorityParam["type"])
enumValues, ok := priorityParam["enum"].([]any)
- if !ok || len(enumValues) != 3 {
- t.Errorf("Expected 3 enum values, got %v", enumValues)
- }
+ require.True(t, ok)
+ require.Len(t, enumValues, 3)
}
func TestGenerateSchemaBasicTypes(t *testing.T) {
@@ -258,9 +212,7 @@ func TestGenerateSchemaBasicTypes(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
schema := generateSchema(reflect.TypeOf(tt.input))
- if schema.Type != tt.expected.Type {
- t.Errorf("Expected type %s, got %s", tt.expected.Type, schema.Type)
- }
+ require.Equal(t, tt.expected.Type, schema.Type)
})
}
}
@@ -303,15 +255,9 @@ func TestGenerateSchemaArrayTypes(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
schema := generateSchema(reflect.TypeOf(tt.input))
- if schema.Type != tt.expected.Type {
- t.Errorf("Expected type %s, got %s", tt.expected.Type, schema.Type)
- }
- if schema.Items == nil {
- t.Fatal("Expected items schema to exist")
- }
- if schema.Items.Type != tt.expected.Items.Type {
- t.Errorf("Expected items type %s, got %s", tt.expected.Items.Type, schema.Items.Type)
- }
+ require.Equal(t, tt.expected.Type, schema.Type)
+ require.NotNil(t, schema.Items, "Expected items schema to exist")
+ require.Equal(t, tt.expected.Items.Type, schema.Items.Type)
})
}
}
@@ -345,9 +291,7 @@ func TestGenerateSchemaMapTypes(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
schema := generateSchema(reflect.TypeOf(tt.input))
- if schema.Type != tt.expected {
- t.Errorf("Expected type %s, got %s", tt.expected, schema.Type)
- }
+ require.Equal(t, tt.expected, schema.Type)
})
}
}
@@ -384,60 +328,36 @@ func TestGenerateSchemaStructTypes(t *testing.T) {
name: "simple struct",
input: SimpleStruct{},
validate: func(t *testing.T, schema Schema) {
- if schema.Type != "object" {
- t.Errorf("Expected type object, got %s", schema.Type)
- }
- if len(schema.Properties) != 2 {
- t.Errorf("Expected 2 properties, got %d", len(schema.Properties))
- }
- if schema.Properties["name"] == nil {
- t.Error("Expected name property to exist")
- }
- if schema.Properties["name"].Description != "The name field" {
- t.Errorf("Expected description 'The name field', got %s", schema.Properties["name"].Description)
- }
- if len(schema.Required) != 2 {
- t.Errorf("Expected 2 required fields, got %d", len(schema.Required))
- }
+ require.Equal(t, "object", schema.Type)
+ require.Len(t, schema.Properties, 2)
+ require.NotNil(t, schema.Properties["name"], "Expected name property to exist")
+ require.Equal(t, "The name field", schema.Properties["name"].Description)
+ require.Len(t, schema.Required, 2)
},
},
{
name: "struct with omitempty",
input: StructWithOmitEmpty{},
validate: func(t *testing.T, schema Schema) {
- if len(schema.Required) != 1 {
- t.Errorf("Expected 1 required field, got %d", len(schema.Required))
- }
- if schema.Required[0] != "required" {
- t.Errorf("Expected required field 'required', got %s", schema.Required[0])
- }
+ require.Len(t, schema.Required, 1)
+ require.Equal(t, "required", schema.Required[0])
},
},
{
name: "struct with json ignore",
input: StructWithJSONIgnore{},
validate: func(t *testing.T, schema Schema) {
- if len(schema.Properties) != 1 {
- t.Errorf("Expected 1 property, got %d", len(schema.Properties))
- }
- if schema.Properties["visible"] == nil {
- t.Error("Expected visible property to exist")
- }
- if schema.Properties["hidden"] != nil {
- t.Error("Expected hidden property to not exist")
- }
+ require.Len(t, schema.Properties, 1)
+ require.NotNil(t, schema.Properties["visible"], "Expected visible property to exist")
+ require.Nil(t, schema.Properties["hidden"], "Expected hidden property to not exist")
},
},
{
name: "struct without json tags",
input: StructWithoutJSONTags{},
validate: func(t *testing.T, schema Schema) {
- if schema.Properties["first_name"] == nil {
- t.Error("Expected first_name property to exist")
- }
- if schema.Properties["last_name"] == nil {
- t.Error("Expected last_name property to exist")
- }
+ require.NotNil(t, schema.Properties["first_name"], "Expected first_name property to exist")
+ require.NotNil(t, schema.Properties["last_name"], "Expected last_name property to exist")
},
},
}
@@ -461,23 +381,13 @@ func TestGenerateSchemaPointerTypes(t *testing.T) {
schema := generateSchema(reflect.TypeOf(StructWithPointers{}))
- if schema.Type != "object" {
- t.Errorf("Expected type object, got %s", schema.Type)
- }
+ require.Equal(t, "object", schema.Type)
- if schema.Properties["name"] == nil {
- t.Fatal("Expected name property to exist")
- }
- if schema.Properties["name"].Type != "string" {
- t.Errorf("Expected name type string, got %s", schema.Properties["name"].Type)
- }
+ require.NotNil(t, schema.Properties["name"], "Expected name property to exist")
+ require.Equal(t, "string", schema.Properties["name"].Type)
- if schema.Properties["age"] == nil {
- t.Fatal("Expected age property to exist")
- }
- if schema.Properties["age"].Type != "integer" {
- t.Errorf("Expected age type integer, got %s", schema.Properties["age"].Type)
- }
+ require.NotNil(t, schema.Properties["age"], "Expected age property to exist")
+ require.Equal(t, "integer", schema.Properties["age"].Type)
}
func TestGenerateSchemaNestedStructs(t *testing.T) {
@@ -495,25 +405,15 @@ func TestGenerateSchemaNestedStructs(t *testing.T) {
schema := generateSchema(reflect.TypeOf(Person{}))
- if schema.Type != "object" {
- t.Errorf("Expected type object, got %s", schema.Type)
- }
+ require.Equal(t, "object", schema.Type)
- if schema.Properties["address"] == nil {
- t.Fatal("Expected address property to exist")
- }
+ require.NotNil(t, schema.Properties["address"], "Expected address property to exist")
addressSchema := schema.Properties["address"]
- if addressSchema.Type != "object" {
- t.Errorf("Expected address type object, got %s", addressSchema.Type)
- }
+ require.Equal(t, "object", addressSchema.Type)
- if addressSchema.Properties["street"] == nil {
- t.Error("Expected street property in address to exist")
- }
- if addressSchema.Properties["city"] == nil {
- t.Error("Expected city property in address to exist")
- }
+ require.NotNil(t, addressSchema.Properties["street"], "Expected street property in address to exist")
+ require.NotNil(t, addressSchema.Properties["city"], "Expected city property in address to exist")
}
func TestGenerateSchemaRecursiveStructs(t *testing.T) {
@@ -526,23 +426,15 @@ func TestGenerateSchemaRecursiveStructs(t *testing.T) {
schema := generateSchema(reflect.TypeOf(Node{}))
- if schema.Type != "object" {
- t.Errorf("Expected type object, got %s", schema.Type)
- }
+ require.Equal(t, "object", schema.Type)
- if schema.Properties["value"] == nil {
- t.Error("Expected value property to exist")
- }
+ require.NotNil(t, schema.Properties["value"], "Expected value property to exist")
- if schema.Properties["next"] == nil {
- t.Error("Expected next property to exist")
- }
+ require.NotNil(t, schema.Properties["next"], "Expected next property to exist")
// The recursive reference should be handled gracefully
nextSchema := schema.Properties["next"]
- if nextSchema.Type != "object" {
- t.Errorf("Expected next type object, got %s", nextSchema.Type)
- }
+ require.Equal(t, "object", nextSchema.Type)
}
func TestGenerateSchemaWithEnumTags(t *testing.T) {
@@ -558,33 +450,21 @@ func TestGenerateSchemaWithEnumTags(t *testing.T) {
// Check level field
levelSchema := schema.Properties["level"]
- if levelSchema == nil {
- t.Fatal("Expected level property to exist")
- }
- if len(levelSchema.Enum) != 4 {
- t.Errorf("Expected 4 enum values for level, got %d", len(levelSchema.Enum))
- }
+ require.NotNil(t, levelSchema, "Expected level property to exist")
+ require.Len(t, levelSchema.Enum, 4)
expectedLevels := []string{"debug", "info", "warn", "error"}
for i, expected := range expectedLevels {
- if levelSchema.Enum[i] != expected {
- t.Errorf("Expected enum value %s, got %v", expected, levelSchema.Enum[i])
- }
+ require.Equal(t, expected, levelSchema.Enum[i])
}
// Check format field
formatSchema := schema.Properties["format"]
- if formatSchema == nil {
- t.Fatal("Expected format property to exist")
- }
- if len(formatSchema.Enum) != 2 {
- t.Errorf("Expected 2 enum values for format, got %d", len(formatSchema.Enum))
- }
+ require.NotNil(t, formatSchema, "Expected format property to exist")
+ require.Len(t, formatSchema.Enum, 2)
// Check required fields (optional should not be required due to omitempty)
expectedRequired := []string{"level", "format"}
- if len(schema.Required) != len(expectedRequired) {
- t.Errorf("Expected %d required fields, got %d", len(expectedRequired), len(schema.Required))
- }
+ require.Len(t, schema.Required, len(expectedRequired))
}
func TestGenerateSchemaComplexTypes(t *testing.T) {
@@ -601,45 +481,25 @@ func TestGenerateSchemaComplexTypes(t *testing.T) {
// Check string slice
stringSliceSchema := schema.Properties["string_slice"]
- if stringSliceSchema == nil {
- t.Fatal("Expected string_slice property to exist")
- }
- if stringSliceSchema.Type != "array" {
- t.Errorf("Expected string_slice type array, got %s", stringSliceSchema.Type)
- }
- if stringSliceSchema.Items.Type != "string" {
- t.Errorf("Expected string_slice items type string, got %s", stringSliceSchema.Items.Type)
- }
+ require.NotNil(t, stringSliceSchema, "Expected string_slice property to exist")
+ require.Equal(t, "array", stringSliceSchema.Type)
+ require.Equal(t, "string", stringSliceSchema.Items.Type)
// Check int map
intMapSchema := schema.Properties["int_map"]
- if intMapSchema == nil {
- t.Fatal("Expected int_map property to exist")
- }
- if intMapSchema.Type != "object" {
- t.Errorf("Expected int_map type object, got %s", intMapSchema.Type)
- }
+ require.NotNil(t, intMapSchema, "Expected int_map property to exist")
+ require.Equal(t, "object", intMapSchema.Type)
// Check nested slice
nestedSliceSchema := schema.Properties["nested_slice"]
- if nestedSliceSchema == nil {
- t.Fatal("Expected nested_slice property to exist")
- }
- if nestedSliceSchema.Type != "array" {
- t.Errorf("Expected nested_slice type array, got %s", nestedSliceSchema.Type)
- }
- if nestedSliceSchema.Items.Type != "object" {
- t.Errorf("Expected nested_slice items type object, got %s", nestedSliceSchema.Items.Type)
- }
+ require.NotNil(t, nestedSliceSchema, "Expected nested_slice property to exist")
+ require.Equal(t, "array", nestedSliceSchema.Type)
+ require.Equal(t, "object", nestedSliceSchema.Items.Type)
// Check interface
interfaceSchema := schema.Properties["interface"]
- if interfaceSchema == nil {
- t.Fatal("Expected interface property to exist")
- }
- if interfaceSchema.Type != "object" {
- t.Errorf("Expected interface type object, got %s", interfaceSchema.Type)
- }
+ require.NotNil(t, interfaceSchema, "Expected interface property to exist")
+ require.Equal(t, "object", interfaceSchema.Type)
}
func TestToSnakeCase(t *testing.T) {
@@ -664,9 +524,7 @@ func TestToSnakeCase(t *testing.T) {
t.Run(tt.input, func(t *testing.T) {
t.Parallel()
result := toSnakeCase(tt.input)
- if result != tt.expected {
- t.Errorf("toSnakeCase(%s) = %s, expected %s", tt.input, result, tt.expected)
- }
+ require.Equal(t, tt.expected, result, "toSnakeCase(%s)", tt.input)
})
}
}
@@ -740,21 +598,14 @@ func TestSchemaToParametersEdgeCases(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := schemaToParameters(tt.schema)
- if len(result) != len(tt.expected) {
- t.Errorf("Expected %d parameters, got %d", len(tt.expected), len(result))
- }
+ require.Len(t, result, len(tt.expected))
for key, expectedValue := range tt.expected {
- if result[key] == nil {
- t.Errorf("Expected parameter %s to exist", key)
- continue
- }
+ require.NotNil(t, result[key], "Expected parameter %s to exist", key)
// Deep comparison would be complex, so we'll check key properties
resultParam := result[key].(map[string]any)
expectedParam := expectedValue.(map[string]any)
for propKey, propValue := range expectedParam {
- if resultParam[propKey] != propValue {
- t.Errorf("Expected %s.%s = %v, got %v", key, propKey, propValue, resultParam[propKey])
- }
+ require.Equal(t, propValue, resultParam[propKey], "Expected %s.%s", key, propKey)
}
}
})
@@ -18,8 +18,8 @@ import (
)
const (
- ProviderName = "anthropic"
- DefaultURL = "https://api.anthropic.com"
+ Name = "anthropic"
+ DefaultURL = "https://api.anthropic.com"
)
type options struct {
@@ -45,7 +45,7 @@ func New(opts ...Option) ai.Provider {
}
providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
- providerOptions.name = cmp.Or(providerOptions.name, ProviderName)
+ providerOptions.name = cmp.Or(providerOptions.name, Name)
return &provider{options: providerOptions}
}
@@ -97,7 +97,7 @@ func (a *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
}
return languageModel{
modelID: modelID,
- provider: fmt.Sprintf("%s.messages", a.options.name),
+ provider: a.options.name,
options: a.options,
client: anthropic.NewClient(anthropicClientOptions...),
}, nil
@@ -2,8 +2,6 @@ package anthropic
import "github.com/charmbracelet/fantasy/ai"
-const Name = "anthropic"
-
type ProviderOptions struct {
SendReasoning *bool `json:"send_reasoning"`
Thinking *ThinkingProviderOption `json:"thinking"`
@@ -1,9 +1 @@
-{
- "language": "en",
- "version": "0.2",
- "flagWords": [],
- "words": [
- "mapstructure",
- "mapstructure"
- ]
-}
+{"language":"en","words":["mapstructure","mapstructure","charmbracelet","providertests","joho","godotenv","stretchr"],"version":"0.2","flagWords":[]}
@@ -17,6 +17,8 @@ import (
"google.golang.org/genai"
)
+const Name = "google"
+
type provider struct {
options options
}
@@ -38,7 +40,7 @@ func New(opts ...Option) ai.Provider {
o(&options)
}
- options.name = cmp.Or(options.name, "google")
+ options.name = cmp.Or(options.name, Name)
return &provider{
options: options,
@@ -101,7 +103,7 @@ func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
}
return &languageModel{
modelID: modelID,
- provider: fmt.Sprintf("%s.generative-ai", g.options.name),
+ provider: g.options.name,
providerOptions: g.options,
client: client,
}, nil
@@ -120,16 +122,26 @@ func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig
systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
- if providerOptions.ThinkingConfig != nil &&
- providerOptions.ThinkingConfig.IncludeThoughts != nil &&
- *providerOptions.ThinkingConfig.IncludeThoughts &&
- strings.HasPrefix(a.provider, "google.vertex.") {
- warnings = append(warnings, ai.CallWarning{
- Type: ai.CallWarningTypeOther,
- Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
- "and might not be supported or could behave unexpectedly with the current Google provider " +
- fmt.Sprintf("(%s)", a.provider),
- })
+ if providerOptions.ThinkingConfig != nil {
+ if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
+ *providerOptions.ThinkingConfig.IncludeThoughts &&
+ strings.HasPrefix(a.provider, "google.vertex.") {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
+ "and might not be supported or could behave unexpectedly with the current Google provider " +
+ fmt.Sprintf("(%s)", a.provider),
+ })
+ }
+
+ if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
+ *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
+ })
+ providerOptions.ThinkingConfig.ThinkingBudget = ai.IntOption(128)
+ }
}
isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")
@@ -1,7 +1,5 @@
package google
-const Name = "google"
-
type ThinkingConfig struct {
ThinkingBudget *int64 `json:"thinking_budget"`
IncludeThoughts *bool `json:"include_thoughts"`
@@ -21,8 +21,8 @@ import (
)
const (
- ProviderName = "openai"
- DefaultURL = "https://api.openai.com/v1"
+ Name = "openai"
+ DefaultURL = "https://api.openai.com/v1"
)
type provider struct {
@@ -50,7 +50,7 @@ func New(opts ...Option) ai.Provider {
}
providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
- providerOptions.name = cmp.Or(providerOptions.name, ProviderName)
+ providerOptions.name = cmp.Or(providerOptions.name, Name)
if providerOptions.organization != "" {
providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
@@ -124,7 +124,7 @@ func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
return languageModel{
modelID: modelID,
- provider: fmt.Sprintf("%s.chat", o.options.name),
+ provider: o.options.name,
options: o.options,
client: openai.NewClient(openaiClientOptions...),
}, nil
@@ -5,8 +5,6 @@ import (
"github.com/openai/openai-go/v2"
)
-const Name = "openai"
-
type ReasoningEffort string
const (
@@ -27,6 +27,12 @@ var languageModelBuilders = []builderPair{
{"google-gemini-2.5-pro", builderGoogleGemini25Pro},
}
+var thinkingLanguageModelBuilders = []builderPair{
+ {"openai-gpt-5", builderOpenaiGpt5},
+ {"anthropic-claude-sonnet", builderAnthropicClaudeSonnet4},
+ {"google-gemini-2.5-pro", builderGoogleGemini25Pro},
+}
+
func builderOpenaiGpt4o(r *recorder.Recorder) (ai.LanguageModel, error) {
provider := openai.New(
openai.WithAPIKey(os.Getenv("OPENAI_API_KEY")),
@@ -43,6 +49,14 @@ func builderOpenaiGpt4oMini(r *recorder.Recorder) (ai.LanguageModel, error) {
return provider.LanguageModel("gpt-4o-mini")
}
+func builderOpenaiGpt5(r *recorder.Recorder) (ai.LanguageModel, error) {
+ provider := openai.New(
+ openai.WithAPIKey(os.Getenv("OPENAI_API_KEY")),
+ openai.WithHTTPClient(&http.Client{Transport: r}),
+ )
+ return provider.LanguageModel("gpt-5")
+}
+
func builderAnthropicClaudeSonnet4(r *recorder.Recorder) (ai.LanguageModel, error) {
provider := anthropic.New(
anthropic.WithAPIKey(os.Getenv("ANTHROPIC_API_KEY")),
@@ -7,7 +7,11 @@ import (
"testing"
"github.com/charmbracelet/fantasy/ai"
+ "github.com/charmbracelet/fantasy/anthropic"
+ "github.com/charmbracelet/fantasy/google"
+ "github.com/charmbracelet/fantasy/openai"
_ "github.com/joho/godotenv/autoload"
+ "github.com/stretchr/testify/require"
)
func TestSimple(t *testing.T) {
@@ -16,9 +20,7 @@ func TestSimple(t *testing.T) {
r := newRecorder(t)
languageModel, err := pair.builder(r)
- if err != nil {
- t.Fatalf("failed to build language model: %v", err)
- }
+ require.NoError(t, err, "failed to build language model")
agent := ai.NewAgent(
languageModel,
@@ -27,16 +29,12 @@ func TestSimple(t *testing.T) {
result, err := agent.Generate(t.Context(), ai.AgentCall{
Prompt: "Say hi in Portuguese",
})
- if err != nil {
- t.Fatalf("failed to generate: %v", err)
- }
+ require.NoError(t, err, "failed to generate")
option1 := "Oi"
option2 := "Olรก"
got := result.Response.Content.Text()
- if !strings.Contains(got, option1) && !strings.Contains(got, option2) {
- t.Fatalf("unexpected response: got %q, want %q or %q", got, option1, option2)
- }
+ require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
})
}
}
@@ -47,9 +45,7 @@ func TestTool(t *testing.T) {
r := newRecorder(t)
languageModel, err := pair.builder(r)
- if err != nil {
- t.Fatalf("failed to build language model: %v", err)
- }
+ require.NoError(t, err, "failed to build language model")
type WeatherInput struct {
Location string `json:"location" description:"the city"`
@@ -71,16 +67,124 @@ func TestTool(t *testing.T) {
result, err := agent.Generate(t.Context(), ai.AgentCall{
Prompt: "What's the weather in Florence?",
})
- if err != nil {
- t.Fatalf("failed to generate: %v", err)
+ require.NoError(t, err, "failed to generate")
+
+ want1 := "Florence"
+ want2 := "40"
+ got := result.Response.Content.Text()
+ require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
+ })
+ }
+}
+
+func TestThinking(t *testing.T) {
+ for _, pair := range thinkingLanguageModelBuilders {
+ t.Run(pair.name, func(t *testing.T) {
+ r := newRecorder(t)
+
+ languageModel, err := pair.builder(r)
+ require.NoError(t, err, "failed to build language model")
+
+ type WeatherInput struct {
+ Location string `json:"location" description:"the city"`
}
+ weatherTool := ai.NewAgentTool(
+ "weather",
+ "Get weather information for a location",
+ func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
+ return ai.NewTextResponse("40 C"), nil
+ },
+ )
+
+ agent := ai.NewAgent(
+ languageModel,
+ ai.WithSystemPrompt("You are a helpful assistant"),
+ ai.WithTools(weatherTool),
+ )
+ result, err := agent.Generate(t.Context(), ai.AgentCall{
+ Prompt: "What's the weather in Florence, Italy?",
+ ProviderOptions: ai.ProviderOptions{
+ "anthropic": &anthropic.ProviderOptions{
+ Thinking: &anthropic.ThinkingProviderOption{
+ BudgetTokens: 10_000,
+ },
+ },
+ "google": &google.ProviderOptions{
+ ThinkingConfig: &google.ThinkingConfig{
+ ThinkingBudget: ai.IntOption(100),
+ IncludeThoughts: ai.BoolOption(true),
+ },
+ },
+ "openai": &openai.ProviderOptions{
+ ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
+ },
+ },
+ })
+ require.NoError(t, err, "failed to generate")
+
want1 := "Florence"
want2 := "40"
got := result.Response.Content.Text()
- if !strings.Contains(got, want1) || !strings.Contains(got, want2) {
- t.Fatalf("unexpected response: got %q, want %q %q", got, want1, want2)
+ require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
+
+ testThinkingSteps(t, languageModel.Provider(), result.Steps)
+ })
+ }
+}
+
+func TestThinkingStreaming(t *testing.T) {
+ for _, pair := range thinkingLanguageModelBuilders {
+ t.Run(pair.name, func(t *testing.T) {
+ r := newRecorder(t)
+
+ languageModel, err := pair.builder(r)
+ require.NoError(t, err, "failed to build language model")
+
+ type WeatherInput struct {
+ Location string `json:"location" description:"the city"`
}
+
+ weatherTool := ai.NewAgentTool(
+ "weather",
+ "Get weather information for a location",
+ func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
+ return ai.NewTextResponse("40 C"), nil
+ },
+ )
+
+ agent := ai.NewAgent(
+ languageModel,
+ ai.WithSystemPrompt("You are a helpful assistant"),
+ ai.WithTools(weatherTool),
+ )
+ result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
+ Prompt: "What's the weather in Florence, Italy?",
+ ProviderOptions: ai.ProviderOptions{
+ "anthropic": &anthropic.ProviderOptions{
+ Thinking: &anthropic.ThinkingProviderOption{
+ BudgetTokens: 10_000,
+ },
+ },
+ "google": &google.ProviderOptions{
+ ThinkingConfig: &google.ThinkingConfig{
+ ThinkingBudget: ai.IntOption(100),
+ IncludeThoughts: ai.BoolOption(true),
+ },
+ },
+ "openai": &openai.ProviderOptions{
+ ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
+ },
+ },
+ })
+ require.NoError(t, err, "failed to generate")
+
+ want1 := "Florence"
+ want2 := "40"
+ got := result.Response.Content.Text()
+ require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
+
+ testThinkingSteps(t, languageModel.Provider(), result.Steps)
})
}
}
@@ -91,9 +195,7 @@ func TestStream(t *testing.T) {
r := newRecorder(t)
languageModel, err := pair.builder(r)
- if err != nil {
- t.Fatalf("failed to build language model: %v", err)
- }
+ require.NoError(t, err, "failed to build language model")
agent := ai.NewAgent(
languageModel,
@@ -118,32 +220,20 @@ func TestStream(t *testing.T) {
}
result, err := agent.Stream(t.Context(), streamCall)
- if err != nil {
- t.Fatalf("failed to stream: %v", err)
- }
+ require.NoError(t, err, "failed to stream")
finalText := result.Response.Content.Text()
- if finalText == "" {
- t.Fatal("expected non-empty response")
- }
+ require.NotEmpty(t, finalText, "expected non-empty response")
- if !strings.Contains(strings.ToLower(finalText), "uno") ||
- !strings.Contains(strings.ToLower(finalText), "dos") ||
- !strings.Contains(strings.ToLower(finalText), "tres") {
- t.Fatalf("unexpected response: %q", finalText)
- }
+ require.True(t, strings.Contains(strings.ToLower(finalText), "uno") &&
+ strings.Contains(strings.ToLower(finalText), "dos") &&
+ strings.Contains(strings.ToLower(finalText), "tres"), "unexpected response: %q", finalText)
- if textDeltaCount == 0 {
- t.Fatal("expected at least one text delta callback")
- }
+ require.Greater(t, textDeltaCount, 0, "expected at least one text delta callback")
- if stepCount == 0 {
- t.Fatal("expected at least one step finish callback")
- }
+ require.Greater(t, stepCount, 0, "expected at least one step finish callback")
- if collectedText.String() == "" {
- t.Fatal("expected collected text from deltas to be non-empty")
- }
+ require.NotEmpty(t, collectedText.String(), "expected collected text from deltas to be non-empty")
})
}
}
@@ -154,9 +244,7 @@ func TestStreamWithTools(t *testing.T) {
r := newRecorder(t)
languageModel, err := pair.builder(r)
- if err != nil {
- t.Fatalf("failed to build language model: %v", err)
- }
+ require.NoError(t, err, "failed to build language model")
type CalculatorInput struct {
A int `json:"a" description:"first number"`
@@ -190,9 +278,7 @@ func TestStreamWithTools(t *testing.T) {
},
OnToolCall: func(toolCall ai.ToolCallContent) error {
toolCallCount++
- if toolCall.ToolName != "add" {
- t.Errorf("unexpected tool name: %s", toolCall.ToolName)
- }
+ require.Equal(t, "add", toolCall.ToolName, "unexpected tool name")
return nil
},
OnToolResult: func(result ai.ToolResultContent) error {
@@ -202,22 +288,14 @@ func TestStreamWithTools(t *testing.T) {
}
result, err := agent.Stream(t.Context(), streamCall)
- if err != nil {
- t.Fatalf("failed to stream: %v", err)
- }
+ require.NoError(t, err, "failed to stream")
finalText := result.Response.Content.Text()
- if !strings.Contains(finalText, "42") {
- t.Fatalf("expected response to contain '42', got: %q", finalText)
- }
+ require.Contains(t, finalText, "42", "expected response to contain '42', got: %q", finalText)
- if toolCallCount == 0 {
- t.Fatal("expected at least one tool call")
- }
+ require.Greater(t, toolCallCount, 0, "expected at least one tool call")
- if toolResultCount == 0 {
- t.Fatal("expected at least one tool result")
- }
+ require.Greater(t, toolResultCount, 0, "expected at least one tool result")
})
}
}
@@ -0,0 +1,63 @@
+---
+version: 2
+interactions:
+- id: 0
+ request:
+ proto: HTTP/1.1
+ proto_major: 1
+ proto_minor: 1
+ content_length: 550
+ host: ""
+ body: "{\"max_tokens\":14096,\"messages\":[{\"content\":[{\"text\":\"What's the weather in Florence, Italy?\",\"type\":\"text\"}],\"role\":\"user\"}],\"model\":\"claude-sonnet-4-20250514\",\"system\":[{\"text\":\"You are a helpful assistant\",\"type\":\"text\"}],\"thinking\":{\"budget_tokens\":10000,\"type\":\"enabled\"},\"tool_choice\":{\"disable_parallel_tool_use\":false,\"type\":\"auto\"},\"tools\":[{\"input_schema\":{\"properties\":{\"location\":{\"description\":\"the city\",\"type\":\"string\"}},\"required\":[\"location\"],\"type\":\"object\"},\"name\":\"weather\",\"description\":\"Get weather information for a location\"}]}"
+ headers:
+ Accept:
+ - application/json
+ Content-Type:
+ - application/json
+ User-Agent:
+ - Anthropic/Go 1.10.0
+ url: https://api.anthropic.com/v1/messages
+ method: POST
+ response:
+ proto: HTTP/2.0
+ proto_major: 2
+ proto_minor: 0
+ content_length: -1
+ uncompressed: true
@@ -0,0 +1,59 @@
+---
+version: 2
+interactions:
+- id: 0
+ request:
+ proto: HTTP/1.1
+ proto_major: 1
+ proto_minor: 1
+ content_length: 550
+ host: generativelanguage.googleapis.com
+ body: "{\"contents\":[{\"parts\":[{\"text\":\"What's the weather in Florence, Italy?\"}],\"role\":\"user\"}],\"generationConfig\":{\"thinkingConfig\":{\"includeThoughts\":true,\"thinkingBudget\":128}},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"},\"toolConfig\":{\"functionCallingConfig\":{\"mode\":\"AUTO\"}},\"tools\":[{\"functionDeclarations\":[{\"description\":\"Get weather information for a location\",\"name\":\"weather\",\"parameters\":{\"properties\":{\"location\":{\"description\":\"the city\",\"type\":\"STRING\"}},\"required\":[\"location\"],\"type\":\"OBJECT\"}}]}]}\n"
+ headers:
+ Content-Type:
+ - application/json
+ User-Agent:
+ - google-genai-sdk/1.23.0 gl-go/go1.25.0
+ url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro:generateContent
+ method: POST
+ response:
+ proto: HTTP/2.0
+ proto_major: 2
+ proto_minor: 0
+ content_length: -1
+ uncompressed: true
@@ -0,0 +1,63 @@
+---
+version: 2
+interactions:
+- id: 0
+ request:
+ proto: HTTP/1.1
+ proto_major: 1
+ proto_minor: 1
+ content_length: 458
+ host: ""
+ body: "{\"messages\":[{\"content\":\"You are a helpful assistant\",\"role\":\"system\"},{\"content\":\"What's the weather in Florence, Italy?\",\"role\":\"user\"}],\"model\":\"gpt-5\",\"reasoning_effort\":\"medium\",\"tool_choice\":\"auto\",\"tools\":[{\"function\":{\"name\":\"weather\",\"strict\":false,\"description\":\"Get weather information for a location\",\"parameters\":{\"properties\":{\"location\":{\"description\":\"the city\",\"type\":\"string\"}},\"required\":[\"location\"],\"type\":\"object\"}},\"type\":\"function\"}]}"
+ headers:
+ Accept:
+ - application/json
+ Content-Type:
+ - application/json
+ User-Agent:
+ - OpenAI/Go 2.3.0
+ url: https://api.openai.com/v1/chat/completions
+ method: POST
+ response:
+ proto: HTTP/2.0
+ proto_major: 2
+ proto_minor: 0
+ content_length: -1
+ uncompressed: true
@@ -0,0 +1,61 @@
+---
+version: 2
+interactions:
+- id: 0
+ request:
+ proto: HTTP/1.1
+ proto_major: 1
+ proto_minor: 1
+ content_length: 564
+ host: ""
+ body: "{\"max_tokens\":14096,\"messages\":[{\"content\":[{\"text\":\"What's the weather in Florence, Italy?\",\"type\":\"text\"}],\"role\":\"user\"}],\"model\":\"claude-sonnet-4-20250514\",\"system\":[{\"text\":\"You are a helpful assistant\",\"type\":\"text\"}],\"thinking\":{\"budget_tokens\":10000,\"type\":\"enabled\"},\"tool_choice\":{\"disable_parallel_tool_use\":false,\"type\":\"auto\"},\"tools\":[{\"input_schema\":{\"properties\":{\"location\":{\"description\":\"the city\",\"type\":\"string\"}},\"required\":[\"location\"],\"type\":\"object\"},\"name\":\"weather\",\"description\":\"Get weather information for a location\"}],\"stream\":true}"
+ headers:
+ Accept:
+ - application/json
+ Content-Type:
+ - application/json
+ User-Agent:
+ - Anthropic/Go 1.10.0
+ url: https://api.anthropic.com/v1/messages
+ method: POST
+ response:
+ proto: HTTP/2.0
+ proto_major: 2
+ proto_minor: 0
+ content_length: -1
@@ -0,0 +1,63 @@
+---
+version: 2
+interactions:
+- id: 0
+ request:
+ proto: HTTP/1.1
+ proto_major: 1
+ proto_minor: 1
+ content_length: 550
+ host: generativelanguage.googleapis.com
+ body: "{\"contents\":[{\"parts\":[{\"text\":\"What's the weather in Florence, Italy?\"}],\"role\":\"user\"}],\"generationConfig\":{\"thinkingConfig\":{\"includeThoughts\":true,\"thinkingBudget\":128}},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"},\"toolConfig\":{\"functionCallingConfig\":{\"mode\":\"AUTO\"}},\"tools\":[{\"functionDeclarations\":[{\"description\":\"Get weather information for a location\",\"name\":\"weather\",\"parameters\":{\"properties\":{\"location\":{\"description\":\"the city\",\"type\":\"STRING\"}},\"required\":[\"location\"],\"type\":\"OBJECT\"}}]}]}\n"
+ form:
+ alt:
+ - sse
+ headers:
+ Content-Type:
+ - application/json
+ User-Agent:
+ - google-genai-sdk/1.23.0 gl-go/go1.25.0
+ url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro:streamGenerateContent?alt=sse
+ method: POST
+ response:
+ proto: HTTP/2.0
+ proto_major: 2
+ proto_minor: 0
+ content_length: -1
@@ -0,0 +1,61 @@
+---
+version: 2
+interactions:
+- id: 0
+ request:
+ proto: HTTP/1.1
+ proto_major: 1
+ proto_minor: 1
+ content_length: 512
+ host: ""
+ body: "{\"messages\":[{\"content\":\"You are a helpful assistant\",\"role\":\"system\"},{\"content\":\"What's the weather in Florence, Italy?\",\"role\":\"user\"}],\"model\":\"gpt-5\",\"reasoning_effort\":\"medium\",\"stream_options\":{\"include_usage\":true},\"tool_choice\":\"auto\",\"tools\":[{\"function\":{\"name\":\"weather\",\"strict\":false,\"description\":\"Get weather information for a location\",\"parameters\":{\"properties\":{\"location\":{\"description\":\"the city\",\"type\":\"string\"}},\"required\":[\"location\"],\"type\":\"object\"}},\"type\":\"function\"}],\"stream\":true}"
+ headers:
+ Accept:
+ - application/json
+ Content-Type:
+ - application/json
+ User-Agent:
+ - OpenAI/Go 2.3.0
+ url: https://api.openai.com/v1/chat/completions
+ method: POST
+ response:
+ proto: HTTP/2.0
+ proto_major: 2
+ proto_minor: 0
+ content_length: -1
@@ -0,0 +1,70 @@
+package providertests
+
+import (
+ "testing"
+
+ "github.com/charmbracelet/fantasy/ai"
+ "github.com/charmbracelet/fantasy/anthropic"
+ "github.com/charmbracelet/fantasy/google"
+ "github.com/stretchr/testify/require"
+)
+
+func testThinkingSteps(t *testing.T, providerName string, steps []ai.StepResult) {
+ switch providerName {
+ case anthropic.Name:
+ testAnthropicThinking(t, steps)
+ case google.Name:
+ testGoogleThinking(t, steps)
+ }
+}
+
+func testGoogleThinking(t *testing.T, steps []ai.StepResult) {
+ reasoningContentCount := 0
+ // Test if we got the signature
+ for _, step := range steps {
+ for _, msg := range step.Messages {
+ for _, content := range msg.Content {
+ if content.GetType() == ai.ContentTypeReasoning {
+ reasoningContentCount += 1
+ }
+ }
+ }
+ }
+ require.Greater(t, reasoningContentCount, 0)
+}
+
+func testAnthropicThinking(t *testing.T, steps []ai.StepResult) {
+ reasoningContentCount := 0
+ signaturesCount := 0
+ // Test if we got the signature
+ for _, step := range steps {
+ for _, msg := range step.Messages {
+ for _, content := range msg.Content {
+ if content.GetType() == ai.ContentTypeReasoning {
+ reasoningContentCount += 1
+ reasoningContent, ok := ai.AsContentType[ai.ReasoningPart](content)
+ if !ok {
+ continue
+ }
+ if len(reasoningContent.ProviderOptions) == 0 {
+ continue
+ }
+
+ anthropicReasoningMetadata, ok := reasoningContent.ProviderOptions[anthropic.Name]
+ if !ok {
+ continue
+ }
+ if reasoningContent.Text != "" {
+ if typed, ok := anthropicReasoningMetadata.(*anthropic.ReasoningOptionMetadata); ok {
+ require.NotEmpty(t, typed.Signature)
+ signaturesCount += 1
+ }
+ }
+ }
+ }
+ }
+ }
+ require.Greater(t, reasoningContentCount, 0)
+ require.Greater(t, signaturesCount, 0)
+ require.Equal(t, reasoningContentCount, signaturesCount)
+}