diff --git a/ai/agent.go b/agent.go similarity index 99% rename from ai/agent.go rename to agent.go index 310f573f74165212c21d78299c737fd4e62b6de7..1f772797bc6b72b59f0bdd7d8147433186fcdef2 100644 --- a/ai/agent.go +++ b/agent.go @@ -1,4 +1,4 @@ -package ai +package fantasy import ( "cmp" diff --git a/ai/agent_stream_test.go b/agent_stream_test.go similarity index 99% rename from ai/agent_stream_test.go rename to agent_stream_test.go index 38d9f5db324e53ff7cf9329b97643ec2686ea777..ea9009305c9c872a44a5074e5d361136250d1910 100644 --- a/ai/agent_stream_test.go +++ b/agent_stream_test.go @@ -1,4 +1,4 @@ -package ai +package fantasy import ( "context" diff --git a/ai/agent_test.go b/agent_test.go similarity index 99% rename from ai/agent_test.go rename to agent_test.go index 728d9691f4ae34f622d2998c252601f0e8c070a4..2929e3c0c98100231f498a8084dc46247d57cb1f 100644 --- a/ai/agent_test.go +++ b/agent_test.go @@ -1,4 +1,4 @@ -package ai +package fantasy import ( "context" diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go index 57b19452073558becf33e90f06b3a8f9dfb4f3f3..1485538977e7064af6b2c62dcf84ac9f9f4cbb06 100644 --- a/anthropic/anthropic.go +++ b/anthropic/anthropic.go @@ -11,7 +11,7 @@ import ( "maps" "strings" - "charm.land/fantasy/ai" + "charm.land/fantasy" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" @@ -45,7 +45,7 @@ type provider struct { type Option = func(*options) -func New(opts ...Option) ai.Provider { +func New(opts ...Option) fantasy.Provider { providerOptions := options{ headers: map[string]string{}, } @@ -107,7 +107,7 @@ func WithHTTPClient(client option.HTTPClient) Option { } } -func (a *provider) LanguageModel(modelID string) (ai.LanguageModel, error) { +func (a *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error) { clientOptions := make([]option.RequestOption, 0, 5+len(a.options.headers)) if a.options.apiKey != "" { clientOptions = append(clientOptions, option.WithAPIKey(a.options.apiKey)) @@ -171,23 +171,23 @@ type languageModel struct { options options } -// Model implements ai.LanguageModel. +// Model implements fantasy.LanguageModel. func (a languageModel) Model() string { return a.modelID } -// Provider implements ai.LanguageModel. +// Provider implements fantasy.LanguageModel. func (a languageModel) Provider() string { return a.provider } -func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) { +func (a languageModel) prepareParams(call fantasy.Call) (*anthropic.MessageNewParams, []fantasy.CallWarning, error) { params := &anthropic.MessageNewParams{} providerOptions := &ProviderOptions{} if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil) + return nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil) } } sendReasoning := true @@ -197,14 +197,14 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, systemBlocks, messages, warnings := toPrompt(call.Prompt, sendReasoning) if call.FrequencyPenalty != nil { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "FrequencyPenalty", }) } if call.PresencePenalty != nil { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "PresencePenalty", }) } @@ -236,29 +236,29 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, } if isThinking { if thinkingBudget == 0 { - return nil, nil, ai.NewUnsupportedFunctionalityError("thinking requires budget", "") + return nil, nil, fantasy.NewUnsupportedFunctionalityError("thinking requires budget", "") } params.Thinking = anthropic.ThinkingConfigParamOfEnabled(thinkingBudget) if call.Temperature != nil { params.Temperature = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "temperature", Details: "temperature is not supported when thinking is enabled", }) } if call.TopP != nil { params.TopP = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "TopP", Details: "TopP is not supported when thinking is enabled", }) } if call.TopK != nil { params.TopK = param.Opt[int64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "TopK", Details: "TopK is not supported when thinking is enabled", }) @@ -286,7 +286,7 @@ func (a *provider) Name() string { return Name } -func getCacheControl(providerOptions ai.ProviderOptions) *CacheControl { +func getCacheControl(providerOptions fantasy.ProviderOptions) *CacheControl { if anthropicOptions, ok := providerOptions[Name]; ok { if options, ok := anthropicOptions.(*ProviderCacheControlOptions); ok { return &options.CacheControl @@ -295,7 +295,7 @@ func getCacheControl(providerOptions ai.ProviderOptions) *CacheControl { return nil } -func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningOptionMetadata { +func getReasoningMetadata(providerOptions fantasy.ProviderOptions) *ReasoningOptionMetadata { if anthropicOptions, ok := providerOptions[Name]; ok { if reasoning, ok := anthropicOptions.(*ReasoningOptionMetadata); ok { return reasoning @@ -305,49 +305,49 @@ func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningOptionMe } type messageBlock struct { - Role ai.MessageRole - Messages []ai.Message + Role fantasy.MessageRole + Messages []fantasy.Message } -func groupIntoBlocks(prompt ai.Prompt) []*messageBlock { +func groupIntoBlocks(prompt fantasy.Prompt) []*messageBlock { var blocks []*messageBlock var currentBlock *messageBlock for _, msg := range prompt { switch msg.Role { - case ai.MessageRoleSystem: - if currentBlock == nil || currentBlock.Role != ai.MessageRoleSystem { + case fantasy.MessageRoleSystem: + if currentBlock == nil || currentBlock.Role != fantasy.MessageRoleSystem { currentBlock = &messageBlock{ - Role: ai.MessageRoleSystem, - Messages: []ai.Message{}, + Role: fantasy.MessageRoleSystem, + Messages: []fantasy.Message{}, } blocks = append(blocks, currentBlock) } currentBlock.Messages = append(currentBlock.Messages, msg) - case ai.MessageRoleUser: - if currentBlock == nil || currentBlock.Role != ai.MessageRoleUser { + case fantasy.MessageRoleUser: + if currentBlock == nil || currentBlock.Role != fantasy.MessageRoleUser { currentBlock = &messageBlock{ - Role: ai.MessageRoleUser, - Messages: []ai.Message{}, + Role: fantasy.MessageRoleUser, + Messages: []fantasy.Message{}, } blocks = append(blocks, currentBlock) } currentBlock.Messages = append(currentBlock.Messages, msg) - case ai.MessageRoleAssistant: - if currentBlock == nil || currentBlock.Role != ai.MessageRoleAssistant { + case fantasy.MessageRoleAssistant: + if currentBlock == nil || currentBlock.Role != fantasy.MessageRoleAssistant { currentBlock = &messageBlock{ - Role: ai.MessageRoleAssistant, - Messages: []ai.Message{}, + Role: fantasy.MessageRoleAssistant, + Messages: []fantasy.Message{}, } blocks = append(blocks, currentBlock) } currentBlock.Messages = append(currentBlock.Messages, msg) - case ai.MessageRoleTool: - if currentBlock == nil || currentBlock.Role != ai.MessageRoleUser { + case fantasy.MessageRoleTool: + if currentBlock == nil || currentBlock.Role != fantasy.MessageRoleUser { currentBlock = &messageBlock{ - Role: ai.MessageRoleUser, - Messages: []ai.Message{}, + Role: fantasy.MessageRoleUser, + Messages: []fantasy.Message{}, } blocks = append(blocks, currentBlock) } @@ -357,10 +357,10 @@ func groupIntoBlocks(prompt ai.Prompt) []*messageBlock { return blocks } -func (a languageModel) toTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []ai.CallWarning) { +func (a languageModel) toTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []fantasy.CallWarning) { for _, tool := range tools { - if tool.GetType() == ai.ToolTypeFunction { - ft, ok := tool.(ai.FunctionTool) + if tool.GetType() == fantasy.ToolTypeFunction { + ft, ok := tool.(fantasy.FunctionTool) if !ok { continue } @@ -391,8 +391,8 @@ func (a languageModel) toTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disab continue } // TODO: handle provider tool calls - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedTool, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedTool, Tool: tool, Message: "tool is not supported", }) @@ -417,21 +417,21 @@ func (a languageModel) toTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disab } switch *toolChoice { - case ai.ToolChoiceAuto: + case fantasy.ToolChoiceAuto: anthropicToolChoice = &anthropic.ToolChoiceUnionParam{ OfAuto: &anthropic.ToolChoiceAutoParam{ Type: "auto", DisableParallelToolUse: disableParallelToolUse, }, } - case ai.ToolChoiceRequired: + case fantasy.ToolChoiceRequired: anthropicToolChoice = &anthropic.ToolChoiceUnionParam{ OfAny: &anthropic.ToolChoiceAnyParam{ Type: "any", DisableParallelToolUse: disableParallelToolUse, }, } - case ai.ToolChoiceNone: + case fantasy.ToolChoiceNone: return anthropicTools, anthropicToolChoice, warnings default: anthropicToolChoice = &anthropic.ToolChoiceUnionParam{ @@ -445,16 +445,16 @@ func (a languageModel) toTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disab return anthropicTools, anthropicToolChoice, warnings } -func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []ai.CallWarning) { +func toPrompt(prompt fantasy.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []fantasy.CallWarning) { var systemBlocks []anthropic.TextBlockParam var messages []anthropic.MessageParam - var warnings []ai.CallWarning + var warnings []fantasy.CallWarning blocks := groupIntoBlocks(prompt) finishedSystemBlock := false for _, block := range blocks { switch block.Role { - case ai.MessageRoleSystem: + case fantasy.MessageRoleSystem: if finishedSystemBlock { // skip multiple system messages that are separated by user/assistant messages // TODO: see if we need to send error here? @@ -468,7 +468,7 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa if cacheControl == nil && isLastPart { cacheControl = getCacheControl(msg.ProviderOptions) } - text, ok := ai.AsMessagePart[ai.TextPart](part) + text, ok := fantasy.AsMessagePart[fantasy.TextPart](part) if !ok { continue } @@ -482,10 +482,10 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa } } - case ai.MessageRoleUser: + case fantasy.MessageRoleUser: var anthropicContent []anthropic.ContentBlockParamUnion for _, msg := range block.Messages { - if msg.Role == ai.MessageRoleUser { + if msg.Role == fantasy.MessageRoleUser { for i, part := range msg.Content { isLastPart := i == len(msg.Content)-1 cacheControl := getCacheControl(part.Options()) @@ -493,8 +493,8 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa cacheControl = getCacheControl(msg.ProviderOptions) } switch part.GetType() { - case ai.ContentTypeText: - text, ok := ai.AsMessagePart[ai.TextPart](part) + case fantasy.ContentTypeText: + text, ok := fantasy.AsMessagePart[fantasy.TextPart](part) if !ok { continue } @@ -507,8 +507,8 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{ OfText: textBlock, }) - case ai.ContentTypeFile: - file, ok := ai.AsMessagePart[ai.FilePart](part) + case fantasy.ContentTypeFile: + file, ok := fantasy.AsMessagePart[fantasy.FilePart](part) if !ok { continue } @@ -525,14 +525,14 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa anthropicContent = append(anthropicContent, imageBlock) } } - } else if msg.Role == ai.MessageRoleTool { + } else if msg.Role == fantasy.MessageRoleTool { for i, part := range msg.Content { isLastPart := i == len(msg.Content)-1 cacheControl := getCacheControl(part.Options()) if cacheControl == nil && isLastPart { cacheControl = getCacheControl(msg.ProviderOptions) } - result, ok := ai.AsMessagePart[ai.ToolResultPart](part) + result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part) if !ok { continue } @@ -540,8 +540,8 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa ToolUseID: result.ToolCallID, } switch result.Output.GetType() { - case ai.ToolResultContentTypeText: - content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output) + case fantasy.ToolResultContentTypeText: + content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output) if !ok { continue } @@ -552,8 +552,8 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa }, }, } - case ai.ToolResultContentTypeMedia: - content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentMedia](result.Output) + case fantasy.ToolResultContentTypeMedia: + content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Output) if !ok { continue } @@ -562,8 +562,8 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa OfImage: anthropic.NewImageBlockBase64(content.MediaType, content.Data).OfImage, }, } - case ai.ToolResultContentTypeError: - content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output) + case fantasy.ToolResultContentTypeError: + content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output) if !ok { continue } @@ -586,7 +586,7 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa } } messages = append(messages, anthropic.NewUserMessage(anthropicContent...)) - case ai.MessageRoleAssistant: + case fantasy.MessageRoleAssistant: var anthropicContent []anthropic.ContentBlockParamUnion for _, msg := range block.Messages { for i, part := range msg.Content { @@ -596,8 +596,8 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa cacheControl = getCacheControl(msg.ProviderOptions) } switch part.GetType() { - case ai.ContentTypeText: - text, ok := ai.AsMessagePart[ai.TextPart](part) + case fantasy.ContentTypeText: + text, ok := fantasy.AsMessagePart[fantasy.TextPart](part) if !ok { continue } @@ -610,13 +610,13 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{ OfText: textBlock, }) - case ai.ContentTypeReasoning: - reasoning, ok := ai.AsMessagePart[ai.ReasoningPart](part) + case fantasy.ContentTypeReasoning: + reasoning, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](part) if !ok { continue } if !sendReasoningData { - warnings = append(warnings, ai.CallWarning{ + warnings = append(warnings, fantasy.CallWarning{ Type: "other", Message: "sending reasoning content is disabled for this model", }) @@ -624,7 +624,7 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa } reasoningMetadata := getReasoningMetadata(part.Options()) if reasoningMetadata == nil { - warnings = append(warnings, ai.CallWarning{ + warnings = append(warnings, fantasy.CallWarning{ Type: "other", Message: "unsupported reasoning metadata", }) @@ -636,14 +636,14 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa } else if reasoningMetadata.RedactedData != "" { anthropicContent = append(anthropicContent, anthropic.NewRedactedThinkingBlock(reasoningMetadata.RedactedData)) } else { - warnings = append(warnings, ai.CallWarning{ + warnings = append(warnings, fantasy.CallWarning{ Type: "other", Message: "unsupported reasoning metadata", }) continue } - case ai.ContentTypeToolCall: - toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part) + case fantasy.ContentTypeToolCall: + toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part) if !ok { continue } @@ -662,7 +662,7 @@ func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockPa toolUseBlock.OfToolUse.CacheControl = anthropic.NewCacheControlEphemeralParam() } anthropicContent = append(anthropicContent, toolUseBlock) - case ai.ContentTypeToolResult: + case fantasy.ContentTypeToolResult: // TODO: implement provider executed tool result } } @@ -683,7 +683,7 @@ func (o languageModel) handleError(err error) error { v := h[len(h)-1] headers[strings.ToLower(k)] = v } - return ai.NewAPICallError( + return fantasy.NewAPICallError( apiErr.Error(), apiErr.Request.URL.String(), string(requestDump), @@ -697,21 +697,21 @@ func (o languageModel) handleError(err error) error { return err } -func mapFinishReason(finishReason string) ai.FinishReason { +func mapFinishReason(finishReason string) fantasy.FinishReason { switch finishReason { case "end_turn", "pause_turn", "stop_sequence": - return ai.FinishReasonStop + return fantasy.FinishReasonStop case "max_tokens": - return ai.FinishReasonLength + return fantasy.FinishReasonLength case "tool_use": - return ai.FinishReasonToolCalls + return fantasy.FinishReasonToolCalls default: - return ai.FinishReasonUnknown + return fantasy.FinishReasonUnknown } } -// Generate implements ai.LanguageModel. -func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { +// Generate implements fantasy.LanguageModel. +func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { params, warnings, err := a.prepareParams(call) if err != nil { return nil, err @@ -721,7 +721,7 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response return nil, a.handleError(err) } - var content []ai.Content + var content []fantasy.Content for _, block := range response.Content { switch block.Type { case "text": @@ -729,7 +729,7 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response if !ok { continue } - content = append(content, ai.TextContent{ + content = append(content, fantasy.TextContent{ Text: text.Text, }) case "thinking": @@ -737,9 +737,9 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response if !ok { continue } - content = append(content, ai.ReasoningContent{ + content = append(content, fantasy.ReasoningContent{ Text: reasoning.Thinking, - ProviderMetadata: ai.ProviderMetadata{ + ProviderMetadata: fantasy.ProviderMetadata{ Name: &ReasoningOptionMetadata{ Signature: reasoning.Signature, }, @@ -750,9 +750,9 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response if !ok { continue } - content = append(content, ai.ReasoningContent{ + content = append(content, fantasy.ReasoningContent{ Text: "", - ProviderMetadata: ai.ProviderMetadata{ + ProviderMetadata: fantasy.ProviderMetadata{ Name: &ReasoningOptionMetadata{ RedactedData: reasoning.Data, }, @@ -763,7 +763,7 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response if !ok { continue } - content = append(content, ai.ToolCallContent{ + content = append(content, fantasy.ToolCallContent{ ToolCallID: toolUse.ID, ToolName: toolUse.Name, Input: string(toolUse.Input), @@ -772,9 +772,9 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response } } - return &ai.Response{ + return &fantasy.Response{ Content: content, - Usage: ai.Usage{ + Usage: fantasy.Usage{ InputTokens: response.Usage.InputTokens, OutputTokens: response.Usage.OutputTokens, TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, @@ -782,13 +782,13 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response CacheReadTokens: response.Usage.CacheReadInputTokens, }, FinishReason: mapFinishReason(string(response.StopReason)), - ProviderMetadata: ai.ProviderMetadata{}, + ProviderMetadata: fantasy.ProviderMetadata{}, Warnings: warnings, }, nil } -// Stream implements ai.LanguageModel. -func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { +// Stream implements fantasy.LanguageModel. +func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { params, warnings, err := a.prepareParams(call) if err != nil { return nil, err @@ -796,10 +796,10 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo stream := a.client.Messages.NewStreaming(ctx, *params) acc := anthropic.Message{} - return func(yield func(ai.StreamPart) bool) { + return func(yield func(fantasy.StreamPart) bool) { if len(warnings) > 0 { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeWarnings, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeWarnings, Warnings: warnings, }) { return @@ -814,24 +814,24 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo contentBlockType := chunk.ContentBlock.Type switch contentBlockType { case "text": - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextStart, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextStart, ID: fmt.Sprintf("%d", chunk.Index), }) { return } case "thinking": - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningStart, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningStart, ID: fmt.Sprintf("%d", chunk.Index), }) { return } case "redacted_thinking": - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningStart, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningStart, ID: fmt.Sprintf("%d", chunk.Index), - ProviderMetadata: ai.ProviderMetadata{ + ProviderMetadata: fantasy.ProviderMetadata{ Name: &ReasoningOptionMetadata{ RedactedData: chunk.ContentBlock.Data, }, @@ -840,8 +840,8 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo return } case "tool_use": - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputStart, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputStart, ID: chunk.ContentBlock.ID, ToolCallName: chunk.ContentBlock.Name, ToolCallInput: "", @@ -856,28 +856,28 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo contentBlock := acc.Content[int(chunk.Index)] switch contentBlock.Type { case "text": - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextEnd, ID: fmt.Sprintf("%d", chunk.Index), }) { return } case "thinking": - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningEnd, ID: fmt.Sprintf("%d", chunk.Index), }) { return } case "tool_use": - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputEnd, ID: contentBlock.ID, }) { return } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolCall, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolCall, ID: contentBlock.ID, ToolCallName: contentBlock.Name, ToolCallInput: string(contentBlock.Input), @@ -888,26 +888,26 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo case "content_block_delta": switch chunk.Delta.Type { case "text_delta": - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextDelta, ID: fmt.Sprintf("%d", chunk.Index), Delta: chunk.Delta.Text, }) { return } case "thinking_delta": - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningDelta, ID: fmt.Sprintf("%d", chunk.Index), Delta: chunk.Delta.Thinking, }) { return } case "signature_delta": - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningDelta, ID: fmt.Sprintf("%d", chunk.Index), - ProviderMetadata: ai.ProviderMetadata{ + ProviderMetadata: fantasy.ProviderMetadata{ Name: &ReasoningOptionMetadata{ Signature: chunk.Delta.Signature, }, @@ -920,8 +920,8 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo continue } contentBlock := acc.Content[int(chunk.Index)] - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputDelta, ID: contentBlock.ID, ToolCallInput: chunk.Delta.PartialJSON, }) { @@ -934,23 +934,23 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo err := stream.Err() if err == nil || errors.Is(err, io.EOF) { - yield(ai.StreamPart{ - Type: ai.StreamPartTypeFinish, + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeFinish, ID: acc.ID, FinishReason: mapFinishReason(string(acc.StopReason)), - Usage: ai.Usage{ + Usage: fantasy.Usage{ InputTokens: acc.Usage.InputTokens, OutputTokens: acc.Usage.OutputTokens, TotalTokens: acc.Usage.InputTokens + acc.Usage.OutputTokens, CacheCreationTokens: acc.Usage.CacheCreationInputTokens, CacheReadTokens: acc.Usage.CacheReadInputTokens, }, - ProviderMetadata: ai.ProviderMetadata{}, + ProviderMetadata: fantasy.ProviderMetadata{}, }) return } else { - yield(ai.StreamPart{ - Type: ai.StreamPartTypeError, + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, Error: a.handleError(err), }) return diff --git a/anthropic/provider_options.go b/anthropic/provider_options.go index 2d7f424ceee4e6ba0b630147a8b893dc83f11e7e..c45c7c25e40c782ac0f03a4fb14d925e574b6f02 100644 --- a/anthropic/provider_options.go +++ b/anthropic/provider_options.go @@ -1,6 +1,6 @@ package anthropic -import "charm.land/fantasy/ai" +import "charm.land/fantasy" type ProviderOptions struct { SendReasoning *bool `json:"send_reasoning"` @@ -31,21 +31,21 @@ type CacheControl struct { Type string `json:"type"` } -func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions { - return ai.ProviderOptions{ +func NewProviderOptions(opts *ProviderOptions) fantasy.ProviderOptions { + return fantasy.ProviderOptions{ Name: opts, } } -func NewProviderCacheControlOptions(opts *ProviderCacheControlOptions) ai.ProviderOptions { - return ai.ProviderOptions{ +func NewProviderCacheControlOptions(opts *ProviderCacheControlOptions) fantasy.ProviderOptions { + return fantasy.ProviderOptions{ Name: opts, } } func ParseOptions(data map[string]any) (*ProviderOptions, error) { var options ProviderOptions - if err := ai.ParseOptions(data, &options); err != nil { + if err := fantasy.ParseOptions(data, &options); err != nil { return nil, err } return &options, nil diff --git a/azure/azure.go b/azure/azure.go index dfe9ae0dee701e07af2334b7cacbc117511dc539..ebb11344226d48289dfc204830935b94a274a6f9 100644 --- a/azure/azure.go +++ b/azure/azure.go @@ -1,7 +1,7 @@ package azure import ( - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/openaicompat" "github.com/openai/openai-go/v2/azure" "github.com/openai/openai-go/v2/option" @@ -22,7 +22,7 @@ const ( type Option = func(*options) -func New(opts ...Option) ai.Provider { +func New(opts ...Option) fantasy.Provider { o := options{ apiVersion: defaultAPIVersion, } diff --git a/bedrock/bedrock.go b/bedrock/bedrock.go index 6d4bb2ddf9214955f058e97e301dc715f9d04d46..68f12a42b27a42eee93eadf9b8408daf3e9a9333 100644 --- a/bedrock/bedrock.go +++ b/bedrock/bedrock.go @@ -1,7 +1,7 @@ package bedrock import ( - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/anthropic" "github.com/anthropics/anthropic-sdk-go/option" ) @@ -17,7 +17,7 @@ const ( type Option = func(*options) -func New(opts ...Option) ai.Provider { +func New(opts ...Option) fantasy.Provider { var o options for _, opt := range opts { opt(&o) diff --git a/ai/content.go b/content.go similarity index 99% rename from ai/content.go rename to content.go index cc431bbdea55ba9df8775ac471c621b127515ce3..82aabc210056c477feef867629efaa885421a56c 100644 --- a/ai/content.go +++ b/content.go @@ -1,4 +1,4 @@ -package ai +package fantasy // ProviderOptionsData is an interface for provider-specific options data. type ProviderOptionsData interface { diff --git a/ai/errors.go b/errors.go similarity index 99% rename from ai/errors.go rename to errors.go index 293f4e1f6abfd3c5b710a53cc6089094bb05f324..515a72730214410cef1f5533e297b9d31ede61de 100644 --- a/ai/errors.go +++ b/errors.go @@ -1,4 +1,4 @@ -package ai +package fantasy import ( "encoding/json" @@ -7,7 +7,7 @@ import ( ) // markerSymbol is used for identifying AI SDK Error instances. -var markerSymbol = "ai.error" +var markerSymbol = "fantasy.error" // AIError is a custom error type for AI SDK related errors. type AIError struct { diff --git a/examples/agent/main.go b/examples/agent/main.go index 79753b5184fe5a131898b80f73bf5220afd6f03c..eed4455f86fe96f8a2c95fa18956aed69c5805d5 100644 --- a/examples/agent/main.go +++ b/examples/agent/main.go @@ -5,7 +5,7 @@ import ( "fmt" "os" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/openrouter" ) @@ -24,21 +24,21 @@ func main() { Location string `json:"location" description:"the city"` } - weatherTool := ai.NewAgentTool( + weatherTool := fantasy.NewAgentTool( "weather", "Get weather information for a location", - func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) { - return ai.NewTextResponse("40 C"), nil + func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.NewTextResponse("40 C"), nil }, ) - agent := ai.NewAgent( + agent := fantasy.NewAgent( model, - ai.WithSystemPrompt("You are a helpful assistant"), - ai.WithTools(weatherTool), + fantasy.WithSystemPrompt("You are a helpful assistant"), + fantasy.WithTools(weatherTool), ) - result, err := agent.Generate(context.Background(), ai.AgentCall{ + result, err := agent.Generate(context.Background(), fantasy.AgentCall{ Prompt: "What's the weather in pristina", }) if err != nil { @@ -49,8 +49,8 @@ func main() { fmt.Println("Steps: ", len(result.Steps)) for _, s := range result.Steps { for _, c := range s.Content { - if c.GetType() == ai.ContentTypeToolCall { - tc, _ := ai.AsContentType[ai.ToolCallContent](c) + if c.GetType() == fantasy.ContentTypeToolCall { + tc, _ := fantasy.AsContentType[fantasy.ToolCallContent](c) fmt.Println("ToolCall: ", tc.ToolName) } } diff --git a/examples/simple/main.go b/examples/simple/main.go index fdaf17810a2f6da01387d2f21e4c84ea782b4f45..1703b1ffc03968f7629c7c9bb530a70d436e4377 100644 --- a/examples/simple/main.go +++ b/examples/simple/main.go @@ -5,7 +5,7 @@ import ( "fmt" "os" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/anthropic" ) @@ -17,11 +17,11 @@ func main() { os.Exit(1) } - response, err := model.Generate(context.Background(), ai.Call{ - Prompt: ai.Prompt{ - ai.NewUserMessage("Hello"), + response, err := model.Generate(context.Background(), fantasy.Call{ + Prompt: fantasy.Prompt{ + fantasy.NewUserMessage("Hello"), }, - Temperature: ai.Opt(0.7), + Temperature: fantasy.Opt(0.7), }) if err != nil { fmt.Println(err) diff --git a/examples/stream/main.go b/examples/stream/main.go index a9bc6cbe5658b039c25f7cc0eb192f0eb85fa3ab..eb8f5bbe3b088d3efa1675ded995272f76decb5c 100644 --- a/examples/stream/main.go +++ b/examples/stream/main.go @@ -6,7 +6,7 @@ import ( "fmt" "os" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/openai" ) @@ -18,13 +18,13 @@ func main() { os.Exit(1) } - stream, err := model.Stream(context.Background(), ai.Call{ - Prompt: ai.Prompt{ - ai.NewUserMessage("Whats the weather in pristina."), + stream, err := model.Stream(context.Background(), fantasy.Call{ + Prompt: fantasy.Prompt{ + fantasy.NewUserMessage("Whats the weather in pristina."), }, - Temperature: ai.Opt(0.7), - Tools: []ai.Tool{ - ai.FunctionTool{ + Temperature: fantasy.Opt(0.7), + Tools: []fantasy.Tool{ + fantasy.FunctionTool{ Name: "weather", Description: "Gets the weather for a location", InputSchema: map[string]any{ diff --git a/examples/streaming-agent-simple/main.go b/examples/streaming-agent-simple/main.go index 0b9f87c60e37c6bde67980fe91ecb7d6a830cce5..3a1df59015ca5a5617781ea19afdf9159cef567f 100644 --- a/examples/streaming-agent-simple/main.go +++ b/examples/streaming-agent-simple/main.go @@ -5,7 +5,7 @@ import ( "fmt" "os" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/openai" ) @@ -32,19 +32,19 @@ func main() { Message string `json:"message" description:"The message to echo back"` } - echoTool := ai.NewAgentTool( + echoTool := fantasy.NewAgentTool( "echo", "Echo back the provided message", - func(ctx context.Context, input EchoInput, _ ai.ToolCall) (ai.ToolResponse, error) { - return ai.NewTextResponse("Echo: " + input.Message), nil + func(ctx context.Context, input EchoInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.NewTextResponse("Echo: " + input.Message), nil }, ) // Create streaming agent - agent := ai.NewAgent( + agent := fantasy.NewAgent( model, - ai.WithSystemPrompt("You are a helpful assistant."), - ai.WithTools(echoTool), + fantasy.WithSystemPrompt("You are a helpful assistant."), + fantasy.WithTools(echoTool), ) ctx := context.Background() @@ -54,7 +54,7 @@ func main() { fmt.Println() // Basic streaming with key callbacks - streamCall := ai.AgentStreamCall{ + streamCall := fantasy.AgentStreamCall{ Prompt: "Please echo back 'Hello, streaming world!'", // Show real-time text as it streams @@ -64,19 +64,19 @@ func main() { }, // Show when tools are called - OnToolCall: func(toolCall ai.ToolCallContent) error { + OnToolCall: func(toolCall fantasy.ToolCallContent) error { fmt.Printf("\n[Tool: %s called]\n", toolCall.ToolName) return nil }, // Show tool results - OnToolResult: func(result ai.ToolResultContent) error { + OnToolResult: func(result fantasy.ToolResultContent) error { fmt.Printf("[Tool result received]\n") return nil }, // Show when each step completes - OnStepFinish: func(step ai.StepResult) error { + OnStepFinish: func(step fantasy.StepResult) error { fmt.Printf("\n[Step completed: %s]\n", step.FinishReason) return nil }, diff --git a/examples/streaming-agent/main.go b/examples/streaming-agent/main.go index 16244621b97878569ba75ae8cad0a4fbde4a1ddc..710c94ccfa5b3b608ba23c500881103c61bda9e2 100644 --- a/examples/streaming-agent/main.go +++ b/examples/streaming-agent/main.go @@ -6,7 +6,7 @@ import ( "os" "strings" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/anthropic" ) @@ -42,10 +42,10 @@ func main() { } // Create weather tool using the new type-safe API - weatherTool := ai.NewAgentTool( + weatherTool := fantasy.NewAgentTool( "get_weather", "Get the current weather for a specific location", - func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) { + func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { // Simulate weather lookup with some fake data location := input.Location if location == "" { @@ -78,33 +78,33 @@ func main() { } weather := fmt.Sprintf("The current weather in %s is %s with partly cloudy skies and light winds.", location, temp) - return ai.NewTextResponse(weather), nil + return fantasy.NewTextResponse(weather), nil }, ) // Create calculator tool using the new type-safe API - calculatorTool := ai.NewAgentTool( + calculatorTool := fantasy.NewAgentTool( "calculate", "Perform basic mathematical calculations", - func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) { + func(ctx context.Context, input CalculatorInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { // Simple calculator simulation expr := strings.TrimSpace(input.Expression) if strings.Contains(expr, "2 + 2") || strings.Contains(expr, "2+2") { - return ai.NewTextResponse("2 + 2 = 4"), nil + return fantasy.NewTextResponse("2 + 2 = 4"), nil } else if strings.Contains(expr, "10 * 5") || strings.Contains(expr, "10*5") { - return ai.NewTextResponse("10 * 5 = 50"), nil + return fantasy.NewTextResponse("10 * 5 = 50"), nil } else if strings.Contains(expr, "15 + 27") || strings.Contains(expr, "15+27") { - return ai.NewTextResponse("15 + 27 = 42"), nil + return fantasy.NewTextResponse("15 + 27 = 42"), nil } - return ai.NewTextResponse("I can calculate simple expressions like '2 + 2', '10 * 5', or '15 + 27'"), nil + return fantasy.NewTextResponse("I can calculate simple expressions like '2 + 2', '10 * 5', or '15 + 27'"), nil }, ) // Create agent with tools - agent := ai.NewAgent( + agent := fantasy.NewAgent( model, - ai.WithSystemPrompt("You are a helpful assistant that can check weather and do calculations. Be concise and friendly."), - ai.WithTools(weatherTool, calculatorTool), + fantasy.WithSystemPrompt("You are a helpful assistant that can check weather and do calculations. Be concise and friendly."), + fantasy.WithTools(weatherTool, calculatorTool), ) ctx := context.Background() @@ -119,14 +119,14 @@ func main() { var reasoningBuffer strings.Builder // Create streaming call with all callbacks - streamCall := ai.AgentStreamCall{ + streamCall := fantasy.AgentStreamCall{ Prompt: "What's the weather in Pristina and what's 2 + 2?", // Agent-level callbacks OnAgentStart: func() { fmt.Println("๐ŸŽฌ Agent started") }, - OnAgentFinish: func(result *ai.AgentResult) error { + OnAgentFinish: func(result *fantasy.AgentResult) error { fmt.Printf("๐Ÿ Agent finished with %d steps, total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens) return nil }, @@ -135,11 +135,11 @@ func main() { fmt.Printf("๐Ÿ“ Step %d started\n", stepNumber+1) return nil }, - OnStepFinish: func(stepResult ai.StepResult) error { + OnStepFinish: func(stepResult fantasy.StepResult) error { fmt.Printf("โœ… Step completed (reason: %s, tokens: %d)\n", stepResult.FinishReason, stepResult.Usage.TotalTokens) return nil }, - OnFinish: func(result *ai.AgentResult) { + OnFinish: func(result *fantasy.AgentResult) { fmt.Printf("๐ŸŽฏ Final result ready with %d steps\n", len(result.Steps)) }, OnError: func(err error) { @@ -147,7 +147,7 @@ func main() { }, // Stream part callbacks - OnWarnings: func(warnings []ai.CallWarning) error { + OnWarnings: func(warnings []fantasy.CallWarning) error { for _, warning := range warnings { fmt.Printf("โš ๏ธ Warning: %s\n", warning.Message) } @@ -166,7 +166,7 @@ func main() { fmt.Println() return nil }, - OnReasoningStart: func(id string, _ ai.ReasoningContent) error { + OnReasoningStart: func(id string, _ fantasy.ReasoningContent) error { fmt.Print("๐Ÿค” Thinking: ") return nil }, @@ -174,7 +174,7 @@ func main() { reasoningBuffer.WriteString(text) return nil }, - OnReasoningEnd: func(id string, content ai.ReasoningContent) error { + OnReasoningEnd: func(id string, content fantasy.ReasoningContent) error { if reasoningBuffer.Len() > 0 { fmt.Printf("%s\n", reasoningBuffer.String()) reasoningBuffer.Reset() @@ -193,26 +193,26 @@ func main() { // Tool input complete return nil }, - OnToolCall: func(toolCall ai.ToolCallContent) error { + OnToolCall: func(toolCall fantasy.ToolCallContent) error { fmt.Printf("๐Ÿ› ๏ธ Tool call: %s\n", toolCall.ToolName) fmt.Printf(" Input: %s\n", toolCall.Input) return nil }, - OnToolResult: func(result ai.ToolResultContent) error { + OnToolResult: func(result fantasy.ToolResultContent) error { fmt.Printf("๐ŸŽฏ Tool result from %s:\n", result.ToolName) switch output := result.Result.(type) { - case ai.ToolResultOutputContentText: + case fantasy.ToolResultOutputContentText: fmt.Printf(" %s\n", output.Text) - case ai.ToolResultOutputContentError: + case fantasy.ToolResultOutputContentError: fmt.Printf(" Error: %s\n", output.Error.Error()) } return nil }, - OnSource: func(source ai.SourceContent) error { + OnSource: func(source fantasy.SourceContent) error { fmt.Printf("๐Ÿ“š Source: %s (%s)\n", source.Title, source.URL) return nil }, - OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderMetadata) error { + OnStreamFinish: func(usage fantasy.Usage, finishReason fantasy.FinishReason, providerMetadata fantasy.ProviderMetadata) error { fmt.Printf("๐Ÿ“Š Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens) return nil }, diff --git a/google/google.go b/google/google.go index 3d5909baa41644456fe86c1fc73c00b8ee900d2c..5082c6614bec187e966583efd366d8c91bba78a4 100644 --- a/google/google.go +++ b/google/google.go @@ -11,7 +11,7 @@ import ( "net/http" "strings" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/anthropic" "cloud.google.com/go/auth" "github.com/charmbracelet/x/exp/slice" @@ -39,7 +39,7 @@ type options struct { type Option = func(*options) -func New(opts ...Option) ai.Provider { +func New(opts ...Option) fantasy.Provider { options := options{ headers: map[string]string{}, } @@ -116,8 +116,8 @@ type languageModel struct { providerOptions options } -// LanguageModel implements ai.Provider. -func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) { +// LanguageModel implements fantasy.Provider. +func (g *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error) { if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") { return anthropic.New( anthropic.WithVertex(g.options.project, g.options.location), @@ -159,14 +159,14 @@ func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) { }, nil } -func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) { +func (a languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) { config := &genai.GenerateContentConfig{} providerOptions := &ProviderOptions{} if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, nil, nil, ai.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil) + return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil) } } @@ -176,8 +176,8 @@ func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig if providerOptions.ThinkingConfig.IncludeThoughts != nil && *providerOptions.ThinkingConfig.IncludeThoughts && strings.HasPrefix(a.provider, "google.vertex.") { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " + "and might not be supported or could behave unexpectedly with the current Google provider " + fmt.Sprintf("(%s)", a.provider), @@ -186,11 +186,11 @@ func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig if providerOptions.ThinkingConfig.ThinkingBudget != nil && *providerOptions.ThinkingConfig.ThinkingBudget < 128 { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default", }) - providerOptions.ThinkingConfig.ThinkingBudget = ai.Opt(int64(128)) + providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128)) } } @@ -271,15 +271,15 @@ func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig return config, content, warnings, nil } -func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) { //nolint: unparam +func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam var systemInstructions *genai.Content var content []*genai.Content - var warnings []ai.CallWarning + var warnings []fantasy.CallWarning finishedSystemBlock := false for _, msg := range prompt { switch msg.Role { - case ai.MessageRoleSystem: + case fantasy.MessageRoleSystem: if finishedSystemBlock { // skip multiple system messages that are separated by user/assistant messages // TODO: see if we need to send error here? @@ -289,7 +289,7 @@ func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.Ca var systemMessages []string for _, part := range msg.Content { - text, ok := ai.AsMessagePart[ai.TextPart](part) + text, ok := fantasy.AsMessagePart[fantasy.TextPart](part) if !ok || text.Text == "" { continue } @@ -304,20 +304,20 @@ func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.Ca }, } } - case ai.MessageRoleUser: + case fantasy.MessageRoleUser: var parts []*genai.Part for _, part := range msg.Content { switch part.GetType() { - case ai.ContentTypeText: - text, ok := ai.AsMessagePart[ai.TextPart](part) + case fantasy.ContentTypeText: + text, ok := fantasy.AsMessagePart[fantasy.TextPart](part) if !ok || text.Text == "" { continue } parts = append(parts, &genai.Part{ Text: text.Text, }) - case ai.ContentTypeFile: - file, ok := ai.AsMessagePart[ai.FilePart](part) + case fantasy.ContentTypeFile: + file, ok := fantasy.AsMessagePart[fantasy.FilePart](part) if !ok { continue } @@ -337,20 +337,20 @@ func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.Ca Parts: parts, }) } - case ai.MessageRoleAssistant: + case fantasy.MessageRoleAssistant: var parts []*genai.Part for _, part := range msg.Content { switch part.GetType() { - case ai.ContentTypeText: - text, ok := ai.AsMessagePart[ai.TextPart](part) + case fantasy.ContentTypeText: + text, ok := fantasy.AsMessagePart[fantasy.TextPart](part) if !ok || text.Text == "" { continue } parts = append(parts, &genai.Part{ Text: text.Text, }) - case ai.ContentTypeToolCall: - toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part) + case fantasy.ContentTypeToolCall: + toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part) if !ok { continue } @@ -375,20 +375,20 @@ func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.Ca Parts: parts, }) } - case ai.MessageRoleTool: + case fantasy.MessageRoleTool: var parts []*genai.Part for _, part := range msg.Content { switch part.GetType() { - case ai.ContentTypeToolResult: - result, ok := ai.AsMessagePart[ai.ToolResultPart](part) + case fantasy.ContentTypeToolResult: + result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part) if !ok { continue } - var toolCall ai.ToolCallPart + var toolCall fantasy.ToolCallPart for _, m := range prompt { - if m.Role == ai.MessageRoleAssistant { + if m.Role == fantasy.MessageRoleAssistant { for _, content := range m.Content { - tc, ok := ai.AsMessagePart[ai.ToolCallPart](content) + tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content) if !ok { continue } @@ -400,8 +400,8 @@ func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.Ca } } switch result.Output.GetType() { - case ai.ToolResultContentTypeText: - content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output) + case fantasy.ToolResultContentTypeText: + content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output) if !ok { continue } @@ -414,8 +414,8 @@ func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.Ca }, }) - case ai.ToolResultContentTypeError: - content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output) + case fantasy.ToolResultContentTypeError: + content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output) if !ok { continue } @@ -443,8 +443,8 @@ func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.Ca return systemInstructions, content, warnings } -// Generate implements ai.LanguageModel. -func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { +// Generate implements fantasy.LanguageModel. +func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { config, contents, warnings, err := g.prepareParams(call) if err != nil { return nil, err @@ -468,18 +468,18 @@ func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Respons return mapResponse(response, warnings) } -// Model implements ai.LanguageModel. +// Model implements fantasy.LanguageModel. func (g *languageModel) Model() string { return g.modelID } -// Provider implements ai.LanguageModel. +// Provider implements fantasy.LanguageModel. func (g *languageModel) Provider() string { return g.provider } -// Stream implements ai.LanguageModel. -func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { +// Stream implements fantasy.LanguageModel. +func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { config, contents, warnings, err := g.prepareParams(call) if err != nil { return nil, err @@ -495,10 +495,10 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp return nil, err } - return func(yield func(ai.StreamPart) bool) { + return func(yield func(fantasy.StreamPart) bool) { if len(warnings) > 0 { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeWarnings, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeWarnings, Warnings: warnings, }) { return @@ -506,19 +506,19 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp } var currentContent string - var toolCalls []ai.ToolCallContent + var toolCalls []fantasy.ToolCallContent var isActiveText bool var isActiveReasoning bool var blockCounter int var currentTextBlockID string var currentReasoningBlockID string - var usage ai.Usage - var lastFinishReason ai.FinishReason + var usage fantasy.Usage + var lastFinishReason fantasy.FinishReason for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) { if err != nil { - yield(ai.StreamPart{ - Type: ai.StreamPartTypeError, + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, Error: err, }) return @@ -535,8 +535,8 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp // End any active text block before starting reasoning if isActiveText { isActiveText = false - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextEnd, ID: currentTextBlockID, }) { return @@ -548,16 +548,16 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp isActiveReasoning = true currentReasoningBlockID = fmt.Sprintf("%d", blockCounter) blockCounter++ - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningStart, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningStart, ID: currentReasoningBlockID, }) { return } } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningDelta, ID: currentReasoningBlockID, Delta: delta, }) { @@ -568,8 +568,8 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp // End any active reasoning block before starting text if isActiveReasoning { isActiveReasoning = false - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningEnd, ID: currentReasoningBlockID, }) { return @@ -581,16 +581,16 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp isActiveText = true currentTextBlockID = fmt.Sprintf("%d", blockCounter) blockCounter++ - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextStart, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextStart, ID: currentTextBlockID, }) { return } } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextDelta, ID: currentTextBlockID, Delta: delta, }) { @@ -603,8 +603,8 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp // End any active text or reasoning blocks if isActiveText { isActiveText = false - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextEnd, ID: currentTextBlockID, }) { return @@ -612,8 +612,8 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp } if isActiveReasoning { isActiveReasoning = false - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningEnd, ID: currentReasoningBlockID, }) { return @@ -624,38 +624,38 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp args, err := json.Marshal(part.FunctionCall.Args) if err != nil { - yield(ai.StreamPart{ - Type: ai.StreamPartTypeError, + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, Error: err, }) return } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputStart, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputStart, ID: toolCallID, ToolCallName: part.FunctionCall.Name, }) { return } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputDelta, ID: toolCallID, Delta: string(args), }) { return } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputEnd, ID: toolCallID, }) { return } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolCall, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolCall, ID: toolCallID, ToolCallName: part.FunctionCall.Name, ToolCallInput: string(args), @@ -664,7 +664,7 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp return } - toolCalls = append(toolCalls, ai.ToolCallContent{ + toolCalls = append(toolCalls, fantasy.ToolCallContent{ ToolCallID: toolCallID, ToolName: part.FunctionCall.Name, Input: string(args), @@ -685,16 +685,16 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp // Close any open blocks before finishing if isActiveText { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextEnd, ID: currentTextBlockID, }) { return } } if isActiveReasoning { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningEnd, ID: currentReasoningBlockID, }) { return @@ -703,23 +703,23 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp finishReason := lastFinishReason if len(toolCalls) > 0 { - finishReason = ai.FinishReasonToolCalls + finishReason = fantasy.FinishReasonToolCalls } else if finishReason == "" { - finishReason = ai.FinishReasonStop + finishReason = fantasy.FinishReasonStop } - yield(ai.StreamPart{ - Type: ai.StreamPartTypeFinish, + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeFinish, Usage: usage, FinishReason: finishReason, }) }, nil } -func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) { +func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) { for _, tool := range tools { - if tool.GetType() == ai.ToolTypeFunction { - ft, ok := tool.(ai.FunctionTool) + if tool.GetType() == fantasy.ToolTypeFunction { + ft, ok := tool.(fantasy.FunctionTool) if !ok { continue } @@ -747,8 +747,8 @@ func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*g continue } // TODO: handle provider tool calls - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedTool, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedTool, Tool: tool, Message: "tool is not supported", }) @@ -757,19 +757,19 @@ func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*g return googleTools, googleToolChoice, warnings } switch *toolChoice { - case ai.ToolChoiceAuto: + case fantasy.ToolChoiceAuto: googleToolChoice = &genai.ToolConfig{ FunctionCallingConfig: &genai.FunctionCallingConfig{ Mode: genai.FunctionCallingConfigModeAuto, }, } - case ai.ToolChoiceRequired: + case fantasy.ToolChoiceRequired: googleToolChoice = &genai.ToolConfig{ FunctionCallingConfig: &genai.FunctionCallingConfig{ Mode: genai.FunctionCallingConfigModeAny, }, } - case ai.ToolChoiceNone: + case fantasy.ToolChoiceNone: googleToolChoice = &genai.ToolConfig{ FunctionCallingConfig: &genai.FunctionCallingConfig{ Mode: genai.FunctionCallingConfigModeNone, @@ -862,14 +862,14 @@ func mapJSONTypeToGoogle(jsonType string) genai.Type { } } -func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) { +func mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) { if len(response.Candidates) == 0 || response.Candidates[0].Content == nil { return nil, errors.New("no response from model") } var ( - content []ai.Content - finishReason ai.FinishReason + content []fantasy.Content + finishReason fantasy.FinishReason hasToolCalls bool candidate = response.Candidates[0] ) @@ -878,9 +878,9 @@ func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarn switch { case part.Text != "": if part.Thought { - content = append(content, ai.ReasoningContent{Text: part.Text}) + content = append(content, fantasy.ReasoningContent{Text: part.Text}) } else { - content = append(content, ai.TextContent{Text: part.Text}) + content = append(content, fantasy.TextContent{Text: part.Text}) } case part.FunctionCall != nil: input, err := json.Marshal(part.FunctionCall.Args) @@ -888,7 +888,7 @@ func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarn return nil, err } toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString()) - content = append(content, ai.ToolCallContent{ + content = append(content, fantasy.ToolCallContent{ ToolCallID: toolCallID, ToolName: part.FunctionCall.Name, Input: string(input), @@ -902,12 +902,12 @@ func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarn } if hasToolCalls { - finishReason = ai.FinishReasonToolCalls + finishReason = fantasy.FinishReasonToolCalls } else { finishReason = mapFinishReason(candidate.FinishReason) } - return &ai.Response{ + return &fantasy.Response{ Content: content, Usage: mapUsage(response.UsageMetadata), FinishReason: finishReason, @@ -915,31 +915,31 @@ func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarn }, nil } -func mapFinishReason(reason genai.FinishReason) ai.FinishReason { +func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason { switch reason { case genai.FinishReasonStop: - return ai.FinishReasonStop + return fantasy.FinishReasonStop case genai.FinishReasonMaxTokens: - return ai.FinishReasonLength + return fantasy.FinishReasonLength case genai.FinishReasonSafety, genai.FinishReasonBlocklist, genai.FinishReasonProhibitedContent, genai.FinishReasonSPII, genai.FinishReasonImageSafety: - return ai.FinishReasonContentFilter + return fantasy.FinishReasonContentFilter case genai.FinishReasonRecitation, genai.FinishReasonLanguage, genai.FinishReasonMalformedFunctionCall: - return ai.FinishReasonError + return fantasy.FinishReasonError case genai.FinishReasonOther: - return ai.FinishReasonOther + return fantasy.FinishReasonOther default: - return ai.FinishReasonUnknown + return fantasy.FinishReasonUnknown } } -func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage { - return ai.Usage{ +func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage { + return fantasy.Usage{ InputTokens: int64(usage.ToolUsePromptTokenCount), OutputTokens: int64(usage.CandidatesTokenCount), TotalTokens: int64(usage.TotalTokenCount), diff --git a/google/provider_options.go b/google/provider_options.go index eed37501643772c41974d628dc94f61280c0df57..27843220397a6c087a4171e17d24420904656a16 100644 --- a/google/provider_options.go +++ b/google/provider_options.go @@ -1,6 +1,6 @@ package google -import "charm.land/fantasy/ai" +import "charm.land/fantasy" type ThinkingConfig struct { ThinkingBudget *int64 `json:"thinking_budget"` @@ -47,7 +47,7 @@ func (o *ProviderOptions) Options() {} func ParseOptions(data map[string]any) (*ProviderOptions, error) { var options ProviderOptions - if err := ai.ParseOptions(data, &options); err != nil { + if err := fantasy.ParseOptions(data, &options); err != nil { return nil, err } return &options, nil diff --git a/ai/model.go b/model.go similarity index 99% rename from ai/model.go rename to model.go index 01fb3dd818e7b9cfd9a4ca2361d6a24b6557d0b4..f08559d5f5979bdf6079e34ab613eaf42f401f7d 100644 --- a/ai/model.go +++ b/model.go @@ -1,4 +1,4 @@ -package ai +package fantasy import ( "context" diff --git a/openai/language_model.go b/openai/language_model.go index e5ac86db2f2a4fbc52a28112987c293f4dac3f8d..6cc4de799ef1583e6e0f9887fae241f7a632c32d 100644 --- a/openai/language_model.go +++ b/openai/language_model.go @@ -9,7 +9,7 @@ import ( "io" "strings" - "charm.land/fantasy/ai" + "charm.land/fantasy" xjson "github.com/charmbracelet/x/json" "github.com/google/uuid" "github.com/openai/openai-go/v2" @@ -93,22 +93,22 @@ type streamToolCall struct { hasFinished bool } -// Model implements ai.LanguageModel. +// Model implements fantasy.LanguageModel. func (o languageModel) Model() string { return o.modelID } -// Provider implements ai.LanguageModel. +// Provider implements fantasy.LanguageModel. func (o languageModel) Provider() string { return o.provider } -func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) { +func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) { params := &openai.ChatCompletionNewParams{} messages, warnings := toPrompt(call.Prompt) if call.TopK != nil { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "top_k", }) } @@ -134,32 +134,32 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar // see https://platform.openai.com/docs/guides/reasoning#limitations if call.Temperature != nil { params.Temperature = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "temperature", Details: "temperature is not supported for reasoning models", }) } if call.TopP != nil { params.TopP = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "TopP", Details: "TopP is not supported for reasoning models", }) } if call.FrequencyPenalty != nil { params.FrequencyPenalty = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "FrequencyPenalty", Details: "FrequencyPenalty is not supported for reasoning models", }) } if call.PresencePenalty != nil { params.PresencePenalty = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "PresencePenalty", Details: "PresencePenalty is not supported for reasoning models", }) @@ -178,8 +178,8 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar if isSearchPreviewModel(o.modelID) { if call.Temperature != nil { params.Temperature = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "temperature", Details: "temperature is not supported for the search preview models and has been removed.", }) @@ -219,7 +219,7 @@ func (o languageModel) handleError(err error) error { v := h[len(h)-1] headers[strings.ToLower(k)] = v } - return ai.NewAPICallError( + return fantasy.NewAPICallError( apiErr.Message, apiErr.Request.URL.String(), string(requestDump), @@ -233,8 +233,8 @@ func (o languageModel) handleError(err error) error { return err } -// Generate implements ai.LanguageModel. -func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { +// Generate implements fantasy.LanguageModel. +func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { params, warnings, err := o.prepareParams(call) if err != nil { return nil, err @@ -248,10 +248,10 @@ func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response return nil, errors.New("no response generated") } choice := response.Choices[0] - content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations)) + content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations)) text := choice.Message.Content if text != "" { - content = append(content, ai.TextContent{ + content = append(content, fantasy.TextContent{ Text: text, }) } @@ -261,7 +261,7 @@ func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response } for _, tc := range choice.Message.ToolCalls { toolCallID := tc.ID - content = append(content, ai.ToolCallContent{ + content = append(content, fantasy.ToolCallContent{ ProviderExecuted: false, // TODO: update when handling other tools ToolCallID: toolCallID, ToolName: tc.Function.Name, @@ -271,8 +271,8 @@ func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response // Handle annotations/citations for _, annotation := range choice.Message.Annotations { if annotation.Type == "url_citation" { - content = append(content, ai.SourceContent{ - SourceType: ai.SourceTypeURL, + content = append(content, fantasy.SourceContent{ + SourceType: fantasy.SourceTypeURL, ID: uuid.NewString(), URL: annotation.URLCitation.URL, Title: annotation.URLCitation.Title, @@ -284,21 +284,21 @@ func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason) if len(choice.Message.ToolCalls) > 0 { - mappedFinishReason = ai.FinishReasonToolCalls + mappedFinishReason = fantasy.FinishReasonToolCalls } - return &ai.Response{ + return &fantasy.Response{ Content: content, Usage: usage, FinishReason: mappedFinishReason, - ProviderMetadata: ai.ProviderMetadata{ + ProviderMetadata: fantasy.ProviderMetadata{ Name: providerMetadata, }, Warnings: warnings, }, nil } -// Stream implements ai.LanguageModel. -func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { +// Stream implements fantasy.LanguageModel. +func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { params, warnings, err := o.prepareParams(call) if err != nil { return nil, err @@ -313,17 +313,17 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo toolCalls := make(map[int64]streamToolCall) // Build provider metadata for streaming - providerMetadata := ai.ProviderMetadata{ + providerMetadata := fantasy.ProviderMetadata{ Name: &ProviderMetadata{}, } acc := openai.ChatCompletionAccumulator{} extraContext := make(map[string]any) - var usage ai.Usage + var usage fantasy.Usage var finishReason string - return func(yield func(ai.StreamPart) bool) { + return func(yield func(fantasy.StreamPart) bool) { if len(warnings) > 0 { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeWarnings, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeWarnings, Warnings: warnings, }) { return @@ -344,15 +344,15 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo case choice.Delta.Content != "": if !isActiveText { isActiveText = true - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextStart, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextStart, ID: "0", }) { return } } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextDelta, ID: "0", Delta: choice.Delta.Content, }) { @@ -361,8 +361,8 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo case len(choice.Delta.ToolCalls) > 0: if isActiveText { isActiveText = false - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextEnd, ID: "0", }) { return @@ -377,8 +377,8 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo if toolCallDelta.Function.Arguments != "" { existingToolCall.arguments += toolCallDelta.Function.Arguments } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputDelta, ID: existingToolCall.id, Delta: toolCallDelta.Function.Arguments, }) { @@ -386,15 +386,15 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo } toolCalls[toolCallDelta.Index] = existingToolCall if xjson.IsValid(existingToolCall.arguments) { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputEnd, ID: existingToolCall.id, }) { return } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolCall, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolCall, ID: existingToolCall.id, ToolCallName: existingToolCall.name, ToolCallInput: existingToolCall.arguments, @@ -408,24 +408,24 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo // Does not exist var err error if toolCallDelta.Type != "function" { - err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.") + err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.") } if toolCallDelta.ID == "" { - err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.") + err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.") } if toolCallDelta.Function.Name == "" { - err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.") + err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.") } if err != nil { - yield(ai.StreamPart{ - Type: ai.StreamPartTypeError, + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, Error: o.handleError(stream.Err()), }) return } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputStart, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputStart, ID: toolCallDelta.ID, ToolCallName: toolCallDelta.Function.Name, }) { @@ -439,23 +439,23 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo exTc := toolCalls[toolCallDelta.Index] if exTc.arguments != "" { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputDelta, ID: exTc.id, Delta: exTc.arguments, }) { return } if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputEnd, ID: toolCallDelta.ID, }) { return } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolCall, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolCall, ID: exTc.id, ToolCallName: exTc.name, ToolCallInput: exTc.arguments, @@ -485,10 +485,10 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 { for _, annotation := range annotations { if annotation.Type == "url_citation" { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeSource, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeSource, ID: uuid.NewString(), - SourceType: ai.SourceTypeURL, + SourceType: fantasy.SourceTypeURL, URL: annotation.URLCitation.URL, Title: annotation.URLCitation.Title, }) { @@ -504,8 +504,8 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo // finished if isActiveText { isActiveText = false - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextEnd, ID: "0", }) { return @@ -520,10 +520,10 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo // Handle annotations/citations from accumulated response for _, annotation := range choice.Message.Annotations { if annotation.Type == "url_citation" { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeSource, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeSource, ID: acc.ID, - SourceType: ai.SourceTypeURL, + SourceType: fantasy.SourceTypeURL, URL: annotation.URLCitation.URL, Title: annotation.URLCitation.Title, }) { @@ -536,19 +536,19 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo if len(acc.Choices) > 0 { choice := acc.Choices[0] if len(choice.Message.ToolCalls) > 0 { - mappedFinishReason = ai.FinishReasonToolCalls + mappedFinishReason = fantasy.FinishReasonToolCalls } } - yield(ai.StreamPart{ - Type: ai.StreamPartTypeFinish, + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeFinish, Usage: usage, FinishReason: mappedFinishReason, ProviderMetadata: providerMetadata, }) return } else { - yield(ai.StreamPart{ - Type: ai.StreamPartTypeError, + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, Error: o.handleError(err), }) return @@ -574,10 +574,10 @@ func supportsPriorityProcessing(modelID string) bool { strings.HasPrefix(modelID, "o4-mini") } -func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) { +func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) { for _, tool := range tools { - if tool.GetType() == ai.ToolTypeFunction { - ft, ok := tool.(ai.FunctionTool) + if tool.GetType() == fantasy.ToolTypeFunction { + ft, ok := tool.(fantasy.FunctionTool) if !ok { continue } @@ -596,8 +596,8 @@ func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []op } // TODO: handle provider tool calls - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedTool, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedTool, Tool: tool, Message: "tool is not supported", }) @@ -607,11 +607,11 @@ func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []op } switch *toolChoice { - case ai.ToolChoiceAuto: + case fantasy.ToolChoiceAuto: openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{ OfAuto: param.NewOpt("auto"), } - case ai.ToolChoiceNone: + case fantasy.ToolChoiceNone: openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{ OfAuto: param.NewOpt("none"), } @@ -628,25 +628,25 @@ func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []op return openAiTools, openAiToolChoice, warnings } -func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) { +func toPrompt(prompt fantasy.Prompt) ([]openai.ChatCompletionMessageParamUnion, []fantasy.CallWarning) { var messages []openai.ChatCompletionMessageParamUnion - var warnings []ai.CallWarning + var warnings []fantasy.CallWarning for _, msg := range prompt { switch msg.Role { - case ai.MessageRoleSystem: + case fantasy.MessageRoleSystem: var systemPromptParts []string for _, c := range msg.Content { - if c.GetType() != ai.ContentTypeText { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + if c.GetType() != fantasy.ContentTypeText { + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "system prompt can only have text content", }) continue } - textPart, ok := ai.AsContentType[ai.TextPart](c) + textPart, ok := fantasy.AsContentType[fantasy.TextPart](c) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "system prompt text part does not have the right type", }) continue @@ -657,20 +657,20 @@ func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai. } } if len(systemPromptParts) == 0 { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "system prompt has no text parts", }) continue } messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n"))) - case ai.MessageRoleUser: + case fantasy.MessageRoleUser: // simple user message just text content - if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText { - textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0]) + if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText { + textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0]) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "user message text part does not have the right type", }) continue @@ -685,11 +685,11 @@ func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai. var content []openai.ChatCompletionContentPartUnionParam for _, c := range msg.Content { switch c.GetType() { - case ai.ContentTypeText: - textPart, ok := ai.AsContentType[ai.TextPart](c) + case fantasy.ContentTypeText: + textPart, ok := fantasy.AsContentType[fantasy.TextPart](c) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "user message text part does not have the right type", }) continue @@ -699,11 +699,11 @@ func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai. Text: textPart.Text, }, }) - case ai.ContentTypeFile: - filePart, ok := ai.AsContentType[ai.FilePart](c) + case fantasy.ContentTypeFile: + filePart, ok := fantasy.AsContentType[fantasy.FilePart](c) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "user message file part does not have the right type", }) continue @@ -781,21 +781,21 @@ func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai. } default: - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType), }) } } } messages = append(messages, openai.UserMessage(content)) - case ai.MessageRoleAssistant: + case fantasy.MessageRoleAssistant: // simple assistant message just text content - if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText { - textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0]) + if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText { + textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0]) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "assistant message text part does not have the right type", }) continue @@ -808,11 +808,11 @@ func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai. } for _, c := range msg.Content { switch c.GetType() { - case ai.ContentTypeText: - textPart, ok := ai.AsContentType[ai.TextPart](c) + case fantasy.ContentTypeText: + textPart, ok := fantasy.AsContentType[fantasy.TextPart](c) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "assistant message text part does not have the right type", }) continue @@ -820,11 +820,11 @@ func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai. assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{ OfString: param.NewOpt(textPart.Text), } - case ai.ContentTypeToolCall: - toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c) + case fantasy.ContentTypeToolCall: + toolCallPart, ok := fantasy.AsContentType[fantasy.ToolCallPart](c) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "assistant message tool part does not have the right type", }) continue @@ -845,42 +845,42 @@ func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai. messages = append(messages, openai.ChatCompletionMessageParamUnion{ OfAssistant: &assistantMsg, }) - case ai.MessageRoleTool: + case fantasy.MessageRoleTool: for _, c := range msg.Content { - if c.GetType() != ai.ContentTypeToolResult { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + if c.GetType() != fantasy.ContentTypeToolResult { + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "tool message can only have tool result content", }) continue } - toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c) + toolResultPart, ok := fantasy.AsContentType[fantasy.ToolResultPart](c) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "tool message result part does not have the right type", }) continue } switch toolResultPart.Output.GetType() { - case ai.ToolResultContentTypeText: - output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output) + case fantasy.ToolResultContentTypeText: + output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](toolResultPart.Output) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "tool result output does not have the right type", }) continue } messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID)) - case ai.ToolResultContentTypeError: + case fantasy.ToolResultContentTypeError: // TODO: check if better handling is needed - output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output) + output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](toolResultPart.Output) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "tool result output does not have the right type", }) continue diff --git a/openai/language_model_hooks.go b/openai/language_model_hooks.go index 2e5a294855c3d1ea915c328729d1be46ef6f1b0c..7d04c5d1b98439f02d4f865955dbbce9e3699c39 100644 --- a/openai/language_model_hooks.go +++ b/openai/language_model_hooks.go @@ -3,32 +3,32 @@ package openai import ( "fmt" - "charm.land/fantasy/ai" + "charm.land/fantasy" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/packages/param" "github.com/openai/openai-go/v2/shared" ) type ( - LanguageModelPrepareCallFunc = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) - LanguageModelMapFinishReasonFunc = func(finishReason string) ai.FinishReason - LanguageModelUsageFunc = func(choice openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) - LanguageModelExtraContentFunc = func(choice openai.ChatCompletionChoice) []ai.Content - LanguageModelStreamExtraFunc = func(chunk openai.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool) - LanguageModelStreamUsageFunc = func(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) - LanguageModelStreamProviderMetadataFunc = func(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata + LanguageModelPrepareCallFunc = func(model fantasy.LanguageModel, params *openai.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) + LanguageModelMapFinishReasonFunc = func(finishReason string) fantasy.FinishReason + LanguageModelUsageFunc = func(choice openai.ChatCompletion) (fantasy.Usage, fantasy.ProviderOptionsData) + LanguageModelExtraContentFunc = func(choice openai.ChatCompletionChoice) []fantasy.Content + LanguageModelStreamExtraFunc = func(chunk openai.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) + LanguageModelStreamUsageFunc = func(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata fantasy.ProviderMetadata) (fantasy.Usage, fantasy.ProviderMetadata) + LanguageModelStreamProviderMetadataFunc = func(choice openai.ChatCompletionChoice, metadata fantasy.ProviderMetadata) fantasy.ProviderMetadata ) -func DefaultPrepareCallFunc(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) { +func DefaultPrepareCallFunc(model fantasy.LanguageModel, params *openai.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) { if call.ProviderOptions == nil { return nil, nil } - var warnings []ai.CallWarning + var warnings []fantasy.CallWarning providerOptions := &ProviderOptions{} if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil) + return nil, fantasy.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil) } } @@ -110,24 +110,24 @@ func DefaultPrepareCallFunc(model ai.LanguageModel, params *openai.ChatCompletio if isReasoningModel(model.Model()) { if providerOptions.LogitBias != nil { params.LogitBias = nil - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "LogitBias", Message: "LogitBias is not supported for reasoning models", }) } if providerOptions.LogProbs != nil { params.Logprobs = param.Opt[bool]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "Logprobs", Message: "Logprobs is not supported for reasoning models", }) } if providerOptions.TopLogProbs != nil { params.TopLogprobs = param.Opt[int64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "TopLogprobs", Message: "TopLogprobs is not supported for reasoning models", }) @@ -139,15 +139,15 @@ func DefaultPrepareCallFunc(model ai.LanguageModel, params *openai.ChatCompletio serviceTier := *providerOptions.ServiceTier if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) { params.ServiceTier = "" - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "ServiceTier", Details: "flex processing is only available for o3, o4-mini, and gpt-5 models", }) } else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) { params.ServiceTier = "" - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "ServiceTier", Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported", }) @@ -156,22 +156,22 @@ func DefaultPrepareCallFunc(model ai.LanguageModel, params *openai.ChatCompletio return warnings, nil } -func DefaultMapFinishReasonFunc(finishReason string) ai.FinishReason { +func DefaultMapFinishReasonFunc(finishReason string) fantasy.FinishReason { switch finishReason { case "stop": - return ai.FinishReasonStop + return fantasy.FinishReasonStop case "length": - return ai.FinishReasonLength + return fantasy.FinishReasonLength case "content_filter": - return ai.FinishReasonContentFilter + return fantasy.FinishReasonContentFilter case "function_call", "tool_calls": - return ai.FinishReasonToolCalls + return fantasy.FinishReasonToolCalls default: - return ai.FinishReasonUnknown + return fantasy.FinishReasonUnknown } } -func DefaultUsageFunc(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) { +func DefaultUsageFunc(response openai.ChatCompletion) (fantasy.Usage, fantasy.ProviderOptionsData) { completionTokenDetails := response.Usage.CompletionTokensDetails promptTokenDetails := response.Usage.PromptTokensDetails @@ -192,7 +192,7 @@ func DefaultUsageFunc(response openai.ChatCompletion) (ai.Usage, ai.ProviderOpti providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens } } - return ai.Usage{ + return fantasy.Usage{ InputTokens: response.Usage.PromptTokens, OutputTokens: response.Usage.CompletionTokens, TotalTokens: response.Usage.TotalTokens, @@ -201,9 +201,9 @@ func DefaultUsageFunc(response openai.ChatCompletion) (ai.Usage, ai.ProviderOpti }, providerMetadata } -func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) { +func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata fantasy.ProviderMetadata) (fantasy.Usage, fantasy.ProviderMetadata) { if chunk.Usage.TotalTokens == 0 { - return ai.Usage{}, nil + return fantasy.Usage{}, nil } streamProviderMetadata := &ProviderMetadata{} if metadata != nil { @@ -217,7 +217,7 @@ func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, ctx map[string]any // we do this here because the acc does not add prompt details completionTokenDetails := chunk.Usage.CompletionTokensDetails promptTokenDetails := chunk.Usage.PromptTokensDetails - usage := ai.Usage{ + usage := fantasy.Usage{ InputTokens: chunk.Usage.PromptTokens, OutputTokens: chunk.Usage.CompletionTokens, TotalTokens: chunk.Usage.TotalTokens, @@ -235,12 +235,12 @@ func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, ctx map[string]any } } - return usage, ai.ProviderMetadata{ + return usage, fantasy.ProviderMetadata{ Name: streamProviderMetadata, } } -func DefaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata { +func DefaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata fantasy.ProviderMetadata) fantasy.ProviderMetadata { streamProviderMetadata, ok := metadata[Name] if !ok { streamProviderMetadata = &ProviderMetadata{} diff --git a/openai/openai.go b/openai/openai.go index d18d74bacb1ec8585ffad425864e3f7446e063a3..96c6b2dab864ee3f8b8905eaea658349fe114e09 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -5,7 +5,7 @@ import ( "cmp" "maps" - "charm.land/fantasy/ai" + "charm.land/fantasy" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" ) @@ -34,7 +34,7 @@ type options struct { type Option = func(*options) -func New(opts ...Option) ai.Provider { +func New(opts ...Option) fantasy.Provider { providerOptions := options{ headers: map[string]string{}, languageModelOptions: make([]LanguageModelOption, 0), @@ -117,8 +117,8 @@ func WithUseResponsesAPI() Option { } } -// LanguageModel implements ai.Provider. -func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) { +// LanguageModel implements fantasy.Provider. +func (o *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error) { openaiClientOptions := make([]option.RequestOption, 0, 5+len(o.options.headers)+len(o.options.sdkOptions)) if o.options.apiKey != "" { diff --git a/openai/openai_test.go b/openai/openai_test.go index 5c40c91e50c75ed1a7282eb452e3295c7d39d778..043cc287dc47943b649f1421532d79a3d208e0b8 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -10,7 +10,7 @@ import ( "strings" "testing" - "charm.land/fantasy/ai" + "charm.land/fantasy" "github.com/openai/openai-go/v2/packages/param" "github.com/stretchr/testify/require" ) @@ -21,11 +21,11 @@ func TestToOpenAiPrompt_SystemMessages(t *testing.T) { t.Run("should forward system messages", func(t *testing.T) { t.Parallel() - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleSystem, - Content: []ai.MessagePart{ - ai.TextPart{Text: "You are a helpful assistant."}, + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "You are a helpful assistant."}, }, }, } @@ -43,10 +43,10 @@ func TestToOpenAiPrompt_SystemMessages(t *testing.T) { t.Run("should handle empty system messages", func(t *testing.T) { t.Parallel() - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleSystem, - Content: []ai.MessagePart{}, + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{}, }, } @@ -60,12 +60,12 @@ func TestToOpenAiPrompt_SystemMessages(t *testing.T) { t.Run("should join multiple system text parts", func(t *testing.T) { t.Parallel() - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleSystem, - Content: []ai.MessagePart{ - ai.TextPart{Text: "You are a helpful assistant."}, - ai.TextPart{Text: "Be concise."}, + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "You are a helpful assistant."}, + fantasy.TextPart{Text: "Be concise."}, }, }, } @@ -87,11 +87,11 @@ func TestToOpenAiPrompt_UserMessages(t *testing.T) { t.Run("should convert messages with only a text part to a string content", func(t *testing.T) { t.Parallel() - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleUser, - Content: []ai.MessagePart{ - ai.TextPart{Text: "Hello"}, + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "Hello"}, }, }, } @@ -110,12 +110,12 @@ func TestToOpenAiPrompt_UserMessages(t *testing.T) { t.Parallel() imageData := []byte{0, 1, 2, 3} - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleUser, - Content: []ai.MessagePart{ - ai.TextPart{Text: "Hello"}, - ai.FilePart{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "Hello"}, + fantasy.FilePart{ MediaType: "image/png", Data: imageData, }, @@ -150,11 +150,11 @@ func TestToOpenAiPrompt_UserMessages(t *testing.T) { t.Parallel() imageData := []byte{0, 1, 2, 3} - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleUser, - Content: []ai.MessagePart{ - ai.FilePart{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.FilePart{ MediaType: "image/png", Data: imageData, ProviderOptions: NewProviderFileOptions(&ProviderFileOptions{ @@ -188,11 +188,11 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { t.Run("should throw for unsupported mime types", func(t *testing.T) { t.Parallel() - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleUser, - Content: []ai.MessagePart{ - ai.FilePart{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.FilePart{ MediaType: "application/something", Data: []byte("test"), }, @@ -211,11 +211,11 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { t.Parallel() audioData := []byte{0, 1, 2, 3} - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleUser, - Content: []ai.MessagePart{ - ai.FilePart{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.FilePart{ MediaType: "audio/wav", Data: audioData, }, @@ -244,11 +244,11 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { t.Parallel() audioData := []byte{0, 1, 2, 3} - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleUser, - Content: []ai.MessagePart{ - ai.FilePart{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.FilePart{ MediaType: "audio/mpeg", Data: audioData, }, @@ -272,11 +272,11 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { t.Parallel() audioData := []byte{0, 1, 2, 3} - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleUser, - Content: []ai.MessagePart{ - ai.FilePart{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.FilePart{ MediaType: "audio/mp3", Data: audioData, }, @@ -300,11 +300,11 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { t.Parallel() pdfData := []byte{1, 2, 3, 4, 5} - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleUser, - Content: []ai.MessagePart{ - ai.FilePart{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.FilePart{ MediaType: "application/pdf", Data: pdfData, Filename: "document.pdf", @@ -334,11 +334,11 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { t.Parallel() pdfData := []byte{1, 2, 3, 4, 5} - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleUser, - Content: []ai.MessagePart{ - ai.FilePart{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.FilePart{ MediaType: "application/pdf", Data: pdfData, Filename: "document.pdf", @@ -364,11 +364,11 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) { t.Parallel() - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleUser, - Content: []ai.MessagePart{ - ai.FilePart{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.FilePart{ MediaType: "application/pdf", Data: []byte("file-pdf-12345"), }, @@ -394,11 +394,11 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { t.Parallel() pdfData := []byte{1, 2, 3, 4, 5} - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleUser, - Content: []ai.MessagePart{ - ai.FilePart{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.FilePart{ MediaType: "application/pdf", Data: pdfData, }, @@ -431,11 +431,11 @@ func TestToOpenAiPrompt_ToolCalls(t *testing.T) { outputResult := map[string]any{"oof": "321rab"} outputJSON, _ := json.Marshal(outputResult) - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleAssistant, - Content: []ai.MessagePart{ - ai.ToolCallPart{ + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.ToolCallPart{ ToolCallID: "quux", ToolName: "thwomp", Input: string(inputJSON), @@ -443,11 +443,11 @@ func TestToOpenAiPrompt_ToolCalls(t *testing.T) { }, }, { - Role: ai.MessageRoleTool, - Content: []ai.MessagePart{ - ai.ToolResultPart{ + Role: fantasy.MessageRoleTool, + Content: []fantasy.MessagePart{ + fantasy.ToolResultPart{ ToolCallID: "quux", - Output: ai.ToolResultOutputContentText{ + Output: fantasy.ToolResultOutputContentText{ Text: string(outputJSON), }, }, @@ -482,19 +482,19 @@ func TestToOpenAiPrompt_ToolCalls(t *testing.T) { t.Run("should handle different tool output types", func(t *testing.T) { t.Parallel() - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleTool, - Content: []ai.MessagePart{ - ai.ToolResultPart{ + Role: fantasy.MessageRoleTool, + Content: []fantasy.MessagePart{ + fantasy.ToolResultPart{ ToolCallID: "text-tool", - Output: ai.ToolResultOutputContentText{ + Output: fantasy.ToolResultOutputContentText{ Text: "Hello world", }, }, - ai.ToolResultPart{ + fantasy.ToolResultPart{ ToolCallID: "error-tool", - Output: ai.ToolResultOutputContentError{ + Output: fantasy.ToolResultOutputContentError{ Error: errors.New("Something went wrong"), }, }, @@ -527,11 +527,11 @@ func TestToOpenAiPrompt_AssistantMessages(t *testing.T) { t.Run("should handle simple text assistant messages", func(t *testing.T) { t.Parallel() - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleAssistant, - Content: []ai.MessagePart{ - ai.TextPart{Text: "Hello, how can I help you?"}, + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "Hello, how can I help you?"}, }, }, } @@ -552,12 +552,12 @@ func TestToOpenAiPrompt_AssistantMessages(t *testing.T) { inputArgs := map[string]any{"query": "test"} inputJSON, _ := json.Marshal(inputArgs) - prompt := ai.Prompt{ + prompt := fantasy.Prompt{ { - Role: ai.MessageRoleAssistant, - Content: []ai.MessagePart{ - ai.TextPart{Text: "Let me search for that."}, - ai.ToolCallPart{ + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "Let me search for that."}, + fantasy.ToolCallPart{ ToolCallID: "call-123", ToolName: "search", Input: string(inputJSON), @@ -583,11 +583,11 @@ func TestToOpenAiPrompt_AssistantMessages(t *testing.T) { }) } -var testPrompt = ai.Prompt{ +var testPrompt = fantasy.Prompt{ { - Role: ai.MessageRoleUser, - Content: []ai.MessagePart{ - ai.TextPart{Text: "Hello"}, + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "Hello"}, }, }, } @@ -815,14 +815,14 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, }) require.NoError(t, err) require.Len(t, result.Content, 1) - textContent, ok := result.Content[0].(ai.TextContent) + textContent, ok := result.Content[0].(fantasy.TextContent) require.True(t, ok) require.Equal(t, "Hello, World!", textContent.Text) }) @@ -847,7 +847,7 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -871,7 +871,7 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -911,7 +911,7 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -937,10 +937,10 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ - LogProbs: ai.Opt(true), + LogProbs: fantasy.Opt(true), }), }) @@ -971,12 +971,12 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, }) require.NoError(t, err) - require.Equal(t, ai.FinishReasonStop, result.FinishReason) + require.Equal(t, fantasy.FinishReasonStop, result.FinishReason) }) t.Run("should support unknown finish reason", func(t *testing.T) { @@ -995,12 +995,12 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, }) require.NoError(t, err) - require.Equal(t, ai.FinishReasonUnknown, result.FinishReason) + require.Equal(t, fantasy.FinishReasonUnknown, result.FinishReason) }) t.Run("should pass the model and the messages", func(t *testing.T) { @@ -1019,7 +1019,7 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -1051,14 +1051,14 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ LogitBias: map[string]int64{ "50256": -100, }, - ParallelToolCalls: ai.Opt(false), - User: ai.Opt("test-user-id"), + ParallelToolCalls: fantasy.Opt(false), + User: fantasy.Opt("test-user-id"), }), }) @@ -1093,7 +1093,7 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("o1-mini") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions( &ProviderOptions{ @@ -1133,10 +1133,10 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-4o") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ - TextVerbosity: ai.Opt("low"), + TextVerbosity: fantasy.Opt("low"), }), }) @@ -1171,10 +1171,10 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, - Tools: []ai.Tool{ - ai.FunctionTool{ + Tools: []fantasy.Tool{ + fantasy.FunctionTool{ Name: "test-tool", InputSchema: map[string]any{ "type": "object", @@ -1189,7 +1189,7 @@ func TestDoGenerate(t *testing.T) { }, }, }, - ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0], + ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0], }) require.NoError(t, err) @@ -1243,10 +1243,10 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, - Tools: []ai.Tool{ - ai.FunctionTool{ + Tools: []fantasy.Tool{ + fantasy.FunctionTool{ Name: "test-tool", InputSchema: map[string]any{ "type": "object", @@ -1261,13 +1261,13 @@ func TestDoGenerate(t *testing.T) { }, }, }, - ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0], + ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0], }) require.NoError(t, err) require.Len(t, result.Content, 1) - toolCall, ok := result.Content[0].(ai.ToolCallContent) + toolCall, ok := result.Content[0].(fantasy.ToolCallContent) require.True(t, ok) require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID) require.Equal(t, "test-tool", toolCall.ToolName) @@ -1301,20 +1301,20 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, }) require.NoError(t, err) require.Len(t, result.Content, 2) - textContent, ok := result.Content[0].(ai.TextContent) + textContent, ok := result.Content[0].(fantasy.TextContent) require.True(t, ok) require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text) - sourceContent, ok := result.Content[1].(ai.SourceContent) + sourceContent, ok := result.Content[1].(fantasy.SourceContent) require.True(t, ok) - require.Equal(t, ai.SourceTypeURL, sourceContent.SourceType) + require.Equal(t, fantasy.SourceTypeURL, sourceContent.SourceType) require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL) require.Equal(t, "Document 1", sourceContent.Title) require.NotEmpty(t, sourceContent.ID) @@ -1343,7 +1343,7 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-4o-mini") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -1378,7 +1378,7 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-4o-mini") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -1406,7 +1406,7 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("o1-preview") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, Temperature: &[]float64{0.5}[0], TopP: &[]float64{0.7}[0], @@ -1435,7 +1435,7 @@ func TestDoGenerate(t *testing.T) { // Should have warnings require.Len(t, result.Warnings, 4) - require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type) + require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type) require.Equal(t, "temperature", result.Warnings[0].Setting) require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models") }) @@ -1454,7 +1454,7 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("o1-preview") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, MaxOutputTokens: &[]int64{1000}[0], }) @@ -1498,7 +1498,7 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("o1-preview") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -1525,10 +1525,10 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("o1-preview") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ - MaxCompletionTokens: ai.Opt(int64(255)), + MaxCompletionTokens: fantasy.Opt(int64(255)), }), }) @@ -1563,7 +1563,7 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ Prediction: map[string]any{ @@ -1607,10 +1607,10 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ - Store: ai.Opt(true), + Store: fantasy.Opt(true), }), }) @@ -1645,7 +1645,7 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ Metadata: map[string]any{ @@ -1687,10 +1687,10 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ - PromptCacheKey: ai.Opt("test-cache-key-123"), + PromptCacheKey: fantasy.Opt("test-cache-key-123"), }), }) @@ -1725,10 +1725,10 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ - SafetyIdentifier: ai.Opt("test-safety-identifier-123"), + SafetyIdentifier: fantasy.Opt("test-safety-identifier-123"), }), }) @@ -1761,7 +1761,7 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-4o-search-preview") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, Temperature: &[]float64{0.7}[0], }) @@ -1774,7 +1774,7 @@ func TestDoGenerate(t *testing.T) { require.Nil(t, call.body["temperature"]) require.Len(t, result.Warnings, 1) - require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type) + require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type) require.Equal(t, "temperature", result.Warnings[0].Setting) require.Contains(t, result.Warnings[0].Details, "search preview models") }) @@ -1795,10 +1795,10 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("o3-mini") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ - ServiceTier: ai.Opt("flex"), + ServiceTier: fantasy.Opt("flex"), }), }) @@ -1831,10 +1831,10 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-4o-mini") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ - ServiceTier: ai.Opt("flex"), + ServiceTier: fantasy.Opt("flex"), }), }) @@ -1845,7 +1845,7 @@ func TestDoGenerate(t *testing.T) { require.Nil(t, call.body["service_tier"]) require.Len(t, result.Warnings, 1) - require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type) + require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type) require.Equal(t, "ServiceTier", result.Warnings[0].Setting) require.Contains(t, result.Warnings[0].Details, "flex processing is only available") }) @@ -1864,10 +1864,10 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-4o-mini") - _, err := model.Generate(context.Background(), ai.Call{ + _, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ - ServiceTier: ai.Opt("priority"), + ServiceTier: fantasy.Opt("priority"), }), }) @@ -1900,10 +1900,10 @@ func TestDoGenerate(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - result, err := model.Generate(context.Background(), ai.Call{ + result, err := model.Generate(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ - ServiceTier: ai.Opt("priority"), + ServiceTier: fantasy.Opt("priority"), }), }) @@ -1914,7 +1914,7 @@ func TestDoGenerate(t *testing.T) { require.Nil(t, call.body["service_tier"]) require.Len(t, result.Warnings, 1) - require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type) + require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type) require.Equal(t, "ServiceTier", result.Warnings[0].Setting) require.Contains(t, result.Warnings[0].Details, "priority processing is only available") }) @@ -2167,14 +2167,14 @@ func (sms *streamingMockServer) prepareErrorStreamResponse() { sms.chunks = chunks } -func collectStreamParts(stream ai.StreamResponse) ([]ai.StreamPart, error) { - var parts []ai.StreamPart +func collectStreamParts(stream fantasy.StreamResponse) ([]fantasy.StreamPart, error) { + var parts []fantasy.StreamPart for part := range stream { parts = append(parts, part) - if part.Type == ai.StreamPartTypeError { + if part.Type == fantasy.StreamPartTypeError { break } - if part.Type == ai.StreamPartTypeFinish { + if part.Type == fantasy.StreamPartTypeFinish { break } } @@ -2207,7 +2207,7 @@ func TestDoStream(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - stream, err := model.Stream(context.Background(), ai.Call{ + stream, err := model.Stream(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -2225,13 +2225,13 @@ func TestDoStream(t *testing.T) { for i, part := range parts { switch part.Type { - case ai.StreamPartTypeTextStart: + case fantasy.StreamPartTypeTextStart: textStart = i - case ai.StreamPartTypeTextDelta: + case fantasy.StreamPartTypeTextDelta: deltas = append(deltas, part.Delta) - case ai.StreamPartTypeTextEnd: + case fantasy.StreamPartTypeTextEnd: textEnd = i - case ai.StreamPartTypeFinish: + case fantasy.StreamPartTypeFinish: finish = i } } @@ -2243,7 +2243,7 @@ func TestDoStream(t *testing.T) { // Check finish part finishPart := parts[finish] - require.Equal(t, ai.FinishReasonStop, finishPart.FinishReason) + require.Equal(t, fantasy.FinishReasonStop, finishPart.FinishReason) require.Equal(t, int64(17), finishPart.Usage.InputTokens) require.Equal(t, int64(227), finishPart.Usage.OutputTokens) require.Equal(t, int64(244), finishPart.Usage.TotalTokens) @@ -2263,10 +2263,10 @@ func TestDoStream(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - stream, err := model.Stream(context.Background(), ai.Call{ + stream, err := model.Stream(context.Background(), fantasy.Call{ Prompt: testPrompt, - Tools: []ai.Tool{ - ai.FunctionTool{ + Tools: []fantasy.Tool{ + fantasy.FunctionTool{ Name: "test-tool", InputSchema: map[string]any{ "type": "object", @@ -2294,15 +2294,15 @@ func TestDoStream(t *testing.T) { for i, part := range parts { switch part.Type { - case ai.StreamPartTypeToolInputStart: + case fantasy.StreamPartTypeToolInputStart: toolInputStart = i require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID) require.Equal(t, "test-tool", part.ToolCallName) - case ai.StreamPartTypeToolInputDelta: + case fantasy.StreamPartTypeToolInputDelta: toolDeltas = append(toolDeltas, part.Delta) - case ai.StreamPartTypeToolInputEnd: + case fantasy.StreamPartTypeToolInputEnd: toolInputEnd = i - case ai.StreamPartTypeToolCall: + case fantasy.StreamPartTypeToolCall: toolCall = i require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID) require.Equal(t, "test-tool", part.ToolCallName) @@ -2349,7 +2349,7 @@ func TestDoStream(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - stream, err := model.Stream(context.Background(), ai.Call{ + stream, err := model.Stream(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -2359,16 +2359,16 @@ func TestDoStream(t *testing.T) { require.NoError(t, err) // Find source part - var sourcePart *ai.StreamPart + var sourcePart *fantasy.StreamPart for _, part := range parts { - if part.Type == ai.StreamPartTypeSource { + if part.Type == fantasy.StreamPartTypeSource { sourcePart = &part break } } require.NotNil(t, sourcePart) - require.Equal(t, ai.SourceTypeURL, sourcePart.SourceType) + require.Equal(t, fantasy.SourceTypeURL, sourcePart.SourceType) require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL) require.Equal(t, "Document 1", sourcePart.Title) require.NotEmpty(t, sourcePart.ID) @@ -2388,7 +2388,7 @@ func TestDoStream(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - stream, err := model.Stream(context.Background(), ai.Call{ + stream, err := model.Stream(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -2401,9 +2401,9 @@ func TestDoStream(t *testing.T) { require.True(t, len(parts) >= 1) // Find error part - var errorPart *ai.StreamPart + var errorPart *fantasy.StreamPart for _, part := range parts { - if part.Type == ai.StreamPartTypeError { + if part.Type == fantasy.StreamPartTypeError { errorPart = &part break } @@ -2429,7 +2429,7 @@ func TestDoStream(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - _, err := model.Stream(context.Background(), ai.Call{ + _, err := model.Stream(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -2477,7 +2477,7 @@ func TestDoStream(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - stream, err := model.Stream(context.Background(), ai.Call{ + stream, err := model.Stream(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -2487,9 +2487,9 @@ func TestDoStream(t *testing.T) { require.NoError(t, err) // Find finish part - var finishPart *ai.StreamPart + var finishPart *fantasy.StreamPart for _, part := range parts { - if part.Type == ai.StreamPartTypeFinish { + if part.Type == fantasy.StreamPartTypeFinish { finishPart = &part break } @@ -2527,7 +2527,7 @@ func TestDoStream(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - stream, err := model.Stream(context.Background(), ai.Call{ + stream, err := model.Stream(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -2537,9 +2537,9 @@ func TestDoStream(t *testing.T) { require.NoError(t, err) // Find finish part - var finishPart *ai.StreamPart + var finishPart *fantasy.StreamPart for _, part := range parts { - if part.Type == ai.StreamPartTypeFinish { + if part.Type == fantasy.StreamPartTypeFinish { finishPart = &part break } @@ -2570,10 +2570,10 @@ func TestDoStream(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - _, err := model.Stream(context.Background(), ai.Call{ + _, err := model.Stream(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ - Store: ai.Opt(true), + Store: fantasy.Opt(true), }), }) @@ -2612,7 +2612,7 @@ func TestDoStream(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-3.5-turbo") - _, err := model.Stream(context.Background(), ai.Call{ + _, err := model.Stream(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ Metadata: map[string]any{ @@ -2658,10 +2658,10 @@ func TestDoStream(t *testing.T) { ) model, _ := provider.LanguageModel("o3-mini") - _, err := model.Stream(context.Background(), ai.Call{ + _, err := model.Stream(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ - ServiceTier: ai.Opt("flex"), + ServiceTier: fantasy.Opt("flex"), }), }) @@ -2700,10 +2700,10 @@ func TestDoStream(t *testing.T) { ) model, _ := provider.LanguageModel("gpt-4o-mini") - _, err := model.Stream(context.Background(), ai.Call{ + _, err := model.Stream(context.Background(), fantasy.Call{ Prompt: testPrompt, ProviderOptions: NewProviderOptions(&ProviderOptions{ - ServiceTier: ai.Opt("priority"), + ServiceTier: fantasy.Opt("priority"), }), }) @@ -2743,7 +2743,7 @@ func TestDoStream(t *testing.T) { ) model, _ := provider.LanguageModel("o1-preview") - stream, err := model.Stream(context.Background(), ai.Call{ + stream, err := model.Stream(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -2755,7 +2755,7 @@ func TestDoStream(t *testing.T) { // Find text parts var textDeltas []string for _, part := range parts { - if part.Type == ai.StreamPartTypeTextDelta { + if part.Type == fantasy.StreamPartTypeTextDelta { textDeltas = append(textDeltas, part.Delta) } } @@ -2789,7 +2789,7 @@ func TestDoStream(t *testing.T) { ) model, _ := provider.LanguageModel("o1-preview") - stream, err := model.Stream(context.Background(), ai.Call{ + stream, err := model.Stream(context.Background(), fantasy.Call{ Prompt: testPrompt, }) @@ -2799,9 +2799,9 @@ func TestDoStream(t *testing.T) { require.NoError(t, err) // Find finish part - var finishPart *ai.StreamPart + var finishPart *fantasy.StreamPart for _, part := range parts { - if part.Type == ai.StreamPartTypeFinish { + if part.Type == fantasy.StreamPartTypeFinish { finishPart = &part break } diff --git a/openai/provider_options.go b/openai/provider_options.go index 84416710fa9ba2d9b7ce28c4dc5d650fb24926cf..964761fdb648fdde9b1ffe827993ed2e0607364e 100644 --- a/openai/provider_options.go +++ b/openai/provider_options.go @@ -1,7 +1,7 @@ package openai import ( - "charm.land/fantasy/ai" + "charm.land/fantasy" "github.com/openai/openai-go/v2" ) @@ -52,21 +52,21 @@ func ReasoningEffortOption(e ReasoningEffort) *ReasoningEffort { return &e } -func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions { - return ai.ProviderOptions{ +func NewProviderOptions(opts *ProviderOptions) fantasy.ProviderOptions { + return fantasy.ProviderOptions{ Name: opts, } } -func NewProviderFileOptions(opts *ProviderFileOptions) ai.ProviderOptions { - return ai.ProviderOptions{ +func NewProviderFileOptions(opts *ProviderFileOptions) fantasy.ProviderOptions { + return fantasy.ProviderOptions{ Name: opts, } } func ParseOptions(data map[string]any) (*ProviderOptions, error) { var options ProviderOptions - if err := ai.ParseOptions(data, &options); err != nil { + if err := fantasy.ParseOptions(data, &options); err != nil { return nil, err } return &options, nil diff --git a/openai/responses_language_model.go b/openai/responses_language_model.go index 46cb8c98894eba59f6f8b5094842f2339ac297d3..40f682228bcff6da97ebadd9ab0ac3bb02e7dccd 100644 --- a/openai/responses_language_model.go +++ b/openai/responses_language_model.go @@ -8,7 +8,7 @@ import ( "fmt" "strings" - "charm.land/fantasy/ai" + "charm.land/fantasy" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/packages/param" @@ -112,8 +112,8 @@ func getResponsesModelConfig(modelID string) responsesModelConfig { } } -func (o responsesLanguageModel) prepareParams(call ai.Call) (*responses.ResponseNewParams, []ai.CallWarning) { - var warnings []ai.CallWarning +func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.ResponseNewParams, []fantasy.CallWarning) { + var warnings []fantasy.CallWarning params := &responses.ResponseNewParams{ Store: param.NewOpt(false), } @@ -121,22 +121,22 @@ func (o responsesLanguageModel) prepareParams(call ai.Call) (*responses.Response modelConfig := getResponsesModelConfig(o.modelID) if call.TopK != nil { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "topK", }) } if call.PresencePenalty != nil { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "presencePenalty", }) } if call.FrequencyPenalty != nil { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "frequencyPenalty", }) } @@ -256,8 +256,8 @@ func (o responsesLanguageModel) prepareParams(call ai.Call) (*responses.Response if modelConfig.isReasoningModel { if call.Temperature != nil { params.Temperature = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "temperature", Details: "temperature is not supported for reasoning models", }) @@ -265,8 +265,8 @@ func (o responsesLanguageModel) prepareParams(call ai.Call) (*responses.Response if call.TopP != nil { params.TopP = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "topP", Details: "topP is not supported for reasoning models", }) @@ -274,16 +274,16 @@ func (o responsesLanguageModel) prepareParams(call ai.Call) (*responses.Response } else { if openaiOptions != nil { if openaiOptions.ReasoningEffort != nil { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "reasoningEffort", Details: "reasoningEffort is not supported for non-reasoning models", }) } if openaiOptions.ReasoningSummary != nil { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "reasoningSummary", Details: "reasoningSummary is not supported for non-reasoning models", }) @@ -293,8 +293,8 @@ func (o responsesLanguageModel) prepareParams(call ai.Call) (*responses.Response if openaiOptions != nil && openaiOptions.ServiceTier != nil { if *openaiOptions.ServiceTier == ServiceTierFlex && !modelConfig.supportsFlexProcessing { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "serviceTier", Details: "flex processing is only available for o3, o4-mini, and gpt-5 models", }) @@ -302,8 +302,8 @@ func (o responsesLanguageModel) prepareParams(call ai.Call) (*responses.Response } if *openaiOptions.ServiceTier == ServiceTierPriority && !modelConfig.supportsPriorityProcessing { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedSetting, Setting: "serviceTier", Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported", }) @@ -322,26 +322,26 @@ func (o responsesLanguageModel) prepareParams(call ai.Call) (*responses.Response return params, warnings } -func toResponsesPrompt(prompt ai.Prompt, systemMessageMode string) (responses.ResponseInputParam, []ai.CallWarning) { +func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string) (responses.ResponseInputParam, []fantasy.CallWarning) { var input responses.ResponseInputParam - var warnings []ai.CallWarning + var warnings []fantasy.CallWarning for _, msg := range prompt { switch msg.Role { - case ai.MessageRoleSystem: + case fantasy.MessageRoleSystem: var systemText string for _, c := range msg.Content { - if c.GetType() != ai.ContentTypeText { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + if c.GetType() != fantasy.ContentTypeText { + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "system prompt can only have text content", }) continue } - textPart, ok := ai.AsContentType[ai.TextPart](c) + textPart, ok := fantasy.AsContentType[fantasy.TextPart](c) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "system prompt text part does not have the right type", }) continue @@ -352,8 +352,8 @@ func toResponsesPrompt(prompt ai.Prompt, systemMessageMode string) (responses.Re } if systemText == "" { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "system prompt has no text parts", }) continue @@ -365,21 +365,21 @@ func toResponsesPrompt(prompt ai.Prompt, systemMessageMode string) (responses.Re case "developer": input = append(input, responses.ResponseInputItemParamOfMessage(systemText, responses.EasyInputMessageRoleDeveloper)) case "remove": - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "system messages are removed for this model", }) } - case ai.MessageRoleUser: + case fantasy.MessageRoleUser: var contentParts responses.ResponseInputMessageContentListParam for i, c := range msg.Content { switch c.GetType() { - case ai.ContentTypeText: - textPart, ok := ai.AsContentType[ai.TextPart](c) + case fantasy.ContentTypeText: + textPart, ok := fantasy.AsContentType[fantasy.TextPart](c) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "user message text part does not have the right type", }) continue @@ -391,11 +391,11 @@ func toResponsesPrompt(prompt ai.Prompt, systemMessageMode string) (responses.Re }, }) - case ai.ContentTypeFile: - filePart, ok := ai.AsContentType[ai.FilePart](c) + case fantasy.ContentTypeFile: + filePart, ok := fantasy.AsContentType[fantasy.FilePart](c) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "user message file part does not have the right type", }) continue @@ -425,8 +425,8 @@ func toResponsesPrompt(prompt ai.Prompt, systemMessageMode string) (responses.Re }, }) } else { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType), }) } @@ -435,25 +435,25 @@ func toResponsesPrompt(prompt ai.Prompt, systemMessageMode string) (responses.Re input = append(input, responses.ResponseInputItemParamOfMessage(contentParts, responses.EasyInputMessageRoleUser)) - case ai.MessageRoleAssistant: + case fantasy.MessageRoleAssistant: for _, c := range msg.Content { switch c.GetType() { - case ai.ContentTypeText: - textPart, ok := ai.AsContentType[ai.TextPart](c) + case fantasy.ContentTypeText: + textPart, ok := fantasy.AsContentType[fantasy.TextPart](c) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "assistant message text part does not have the right type", }) continue } input = append(input, responses.ResponseInputItemParamOfMessage(textPart.Text, responses.EasyInputMessageRoleAssistant)) - case ai.ContentTypeToolCall: - toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c) + case fantasy.ContentTypeToolCall: + toolCallPart, ok := fantasy.AsContentType[fantasy.ToolCallPart](c) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "assistant message tool call part does not have the right type", }) continue @@ -465,22 +465,22 @@ func toResponsesPrompt(prompt ai.Prompt, systemMessageMode string) (responses.Re inputJSON, err := json.Marshal(toolCallPart.Input) if err != nil { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: fmt.Sprintf("failed to marshal tool call input: %v", err), }) continue } input = append(input, responses.ResponseInputItemParamOfFunctionCall(string(inputJSON), toolCallPart.ToolCallID, toolCallPart.ToolName)) - case ai.ContentTypeReasoning: + case fantasy.ContentTypeReasoning: reasoningMetadata := getReasoningMetadata(c.Options()) if reasoningMetadata == nil || reasoningMetadata.ItemID == "" { continue } if len(reasoningMetadata.Summary) == 0 && reasoningMetadata.EncryptedContent == nil { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "assistant message reasoning part does is empty", }) continue @@ -506,20 +506,20 @@ func toResponsesPrompt(prompt ai.Prompt, systemMessageMode string) (responses.Re } } - case ai.MessageRoleTool: + case fantasy.MessageRoleTool: for _, c := range msg.Content { - if c.GetType() != ai.ContentTypeToolResult { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + if c.GetType() != fantasy.ContentTypeToolResult { + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "tool message can only have tool result content", }) continue } - toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c) + toolResultPart, ok := fantasy.AsContentType[fantasy.ToolResultPart](c) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "tool message result part does not have the right type", }) continue @@ -527,31 +527,31 @@ func toResponsesPrompt(prompt ai.Prompt, systemMessageMode string) (responses.Re var outputStr string switch toolResultPart.Output.GetType() { - case ai.ToolResultContentTypeText: - output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output) + case fantasy.ToolResultContentTypeText: + output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](toolResultPart.Output) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "tool result output does not have the right type", }) continue } outputStr = output.Text - case ai.ToolResultContentTypeError: - output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output) + case fantasy.ToolResultContentTypeError: + output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](toolResultPart.Output) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "tool result output does not have the right type", }) continue } outputStr = output.Error.Error() - case ai.ToolResultContentTypeMedia: - output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentMedia](toolResultPart.Output) + case fantasy.ToolResultContentTypeMedia: + output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResultPart.Output) if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: "tool result output does not have the right type", }) continue @@ -563,8 +563,8 @@ func toResponsesPrompt(prompt ai.Prompt, systemMessageMode string) (responses.Re } jsonBytes, err := json.Marshal(mediaContent) if err != nil { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeOther, Message: fmt.Sprintf("failed to marshal tool result: %v", err), }) continue @@ -580,8 +580,8 @@ func toResponsesPrompt(prompt ai.Prompt, systemMessageMode string) (responses.Re return input, warnings } -func toResponsesTools(tools []ai.Tool, toolChoice *ai.ToolChoice, options *ResponsesProviderOptions) ([]responses.ToolUnionParam, responses.ResponseNewParamsToolChoiceUnion, []ai.CallWarning) { - warnings := make([]ai.CallWarning, 0) +func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, options *ResponsesProviderOptions) ([]responses.ToolUnionParam, responses.ResponseNewParamsToolChoiceUnion, []fantasy.CallWarning) { + warnings := make([]fantasy.CallWarning, 0) var openaiTools []responses.ToolUnionParam if len(tools) == 0 { @@ -594,8 +594,8 @@ func toResponsesTools(tools []ai.Tool, toolChoice *ai.ToolChoice, options *Respo } for _, tool := range tools { - if tool.GetType() == ai.ToolTypeFunction { - ft, ok := tool.(ai.FunctionTool) + if tool.GetType() == fantasy.ToolTypeFunction { + ft, ok := tool.(fantasy.FunctionTool) if !ok { continue } @@ -611,8 +611,8 @@ func toResponsesTools(tools []ai.Tool, toolChoice *ai.ToolChoice, options *Respo continue } - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedTool, + warnings = append(warnings, fantasy.CallWarning{ + Type: fantasy.CallWarningTypeUnsupportedTool, Tool: tool, Message: "tool is not supported", }) @@ -625,15 +625,15 @@ func toResponsesTools(tools []ai.Tool, toolChoice *ai.ToolChoice, options *Respo var openaiToolChoice responses.ResponseNewParamsToolChoiceUnion switch *toolChoice { - case ai.ToolChoiceAuto: + case fantasy.ToolChoiceAuto: openaiToolChoice = responses.ResponseNewParamsToolChoiceUnion{ OfToolChoiceMode: param.NewOpt(responses.ToolChoiceOptionsAuto), } - case ai.ToolChoiceNone: + case fantasy.ToolChoiceNone: openaiToolChoice = responses.ResponseNewParamsToolChoiceUnion{ OfToolChoiceMode: param.NewOpt(responses.ToolChoiceOptionsNone), } - case ai.ToolChoiceRequired: + case fantasy.ToolChoiceRequired: openaiToolChoice = responses.ResponseNewParamsToolChoiceUnion{ OfToolChoiceMode: param.NewOpt(responses.ToolChoiceOptionsRequired), } @@ -659,7 +659,7 @@ func (o responsesLanguageModel) handleError(err error) error { v := h[len(h)-1] headers[strings.ToLower(k)] = v } - return ai.NewAPICallError( + return fantasy.NewAPICallError( apiErr.Message, apiErr.Request.URL.String(), string(requestDump), @@ -673,7 +673,7 @@ func (o responsesLanguageModel) handleError(err error) error { return err } -func (o responsesLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { +func (o responsesLanguageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { params, warnings := o.prepareParams(call) response, err := o.client.Responses.New(ctx, *params) if err != nil { @@ -684,7 +684,7 @@ func (o responsesLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai return nil, o.handleError(fmt.Errorf("response error: %s (code: %s)", response.Error.Message, response.Error.Code)) } - var content []ai.Content + var content []fantasy.Content hasFunctionCall := false for _, outputItem := range response.Output { @@ -692,15 +692,15 @@ func (o responsesLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai case "message": for _, contentPart := range outputItem.Content { if contentPart.Type == "output_text" { - content = append(content, ai.TextContent{ + content = append(content, fantasy.TextContent{ Text: contentPart.Text, }) for _, annotation := range contentPart.Annotations { switch annotation.Type { case "url_citation": - content = append(content, ai.SourceContent{ - SourceType: ai.SourceTypeURL, + content = append(content, fantasy.SourceContent{ + SourceType: fantasy.SourceTypeURL, ID: uuid.NewString(), URL: annotation.URL, Title: annotation.Title, @@ -714,8 +714,8 @@ func (o responsesLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai if filename == "" { filename = annotation.FileID } - content = append(content, ai.SourceContent{ - SourceType: ai.SourceTypeDocument, + content = append(content, fantasy.SourceContent{ + SourceType: fantasy.SourceTypeDocument, ID: uuid.NewString(), MediaType: "text/plain", Title: title, @@ -728,7 +728,7 @@ func (o responsesLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai case "function_call": hasFunctionCall = true - content = append(content, ai.ToolCallContent{ + content = append(content, fantasy.ToolCallContent{ ProviderExecuted: false, ToolCallID: outputItem.CallID, ToolName: outputItem.Name, @@ -760,16 +760,16 @@ func (o responsesLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai summaryParts = append(summaryParts, s.Text) } - content = append(content, ai.ReasoningContent{ + content = append(content, fantasy.ReasoningContent{ Text: strings.Join(summaryParts, "\n"), - ProviderMetadata: ai.ProviderMetadata{ + ProviderMetadata: fantasy.ProviderMetadata{ Name: metadata, }, }) } } - usage := ai.Usage{ + usage := fantasy.Usage{ InputTokens: response.Usage.InputTokens, OutputTokens: response.Usage.OutputTokens, TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, @@ -784,47 +784,47 @@ func (o responsesLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, hasFunctionCall) - return &ai.Response{ + return &fantasy.Response{ Content: content, Usage: usage, FinishReason: finishReason, - ProviderMetadata: ai.ProviderMetadata{}, + ProviderMetadata: fantasy.ProviderMetadata{}, Warnings: warnings, }, nil } -func mapResponsesFinishReason(reason string, hasFunctionCall bool) ai.FinishReason { +func mapResponsesFinishReason(reason string, hasFunctionCall bool) fantasy.FinishReason { if hasFunctionCall { - return ai.FinishReasonToolCalls + return fantasy.FinishReasonToolCalls } switch reason { case "": - return ai.FinishReasonStop + return fantasy.FinishReasonStop case "max_tokens", "max_output_tokens": - return ai.FinishReasonLength + return fantasy.FinishReasonLength case "content_filter": - return ai.FinishReasonContentFilter + return fantasy.FinishReasonContentFilter default: - return ai.FinishReasonOther + return fantasy.FinishReasonOther } } -func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { +func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { params, warnings := o.prepareParams(call) stream := o.client.Responses.NewStreaming(ctx, *params) - finishReason := ai.FinishReasonUnknown - var usage ai.Usage + finishReason := fantasy.FinishReasonUnknown + var usage fantasy.Usage ongoingToolCalls := make(map[int64]*ongoingToolCall) hasFunctionCall := false activeReasoning := make(map[string]*reasoningState) - return func(yield func(ai.StreamPart) bool) { + return func(yield func(fantasy.StreamPart) bool) { if len(warnings) > 0 { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeWarnings, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeWarnings, Warnings: warnings, }) { return @@ -846,8 +846,8 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St toolName: added.Item.Name, toolCallID: added.Item.CallID, } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputStart, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputStart, ID: added.Item.CallID, ToolCallName: added.Item.Name, }) { @@ -855,8 +855,8 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St } case "message": - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextStart, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextStart, ID: added.Item.ID, }) { return @@ -874,10 +874,10 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St activeReasoning[added.Item.ID] = &reasoningState{ metadata: metadata, } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningStart, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningStart, ID: added.Item.ID, - ProviderMetadata: ai.ProviderMetadata{ + ProviderMetadata: fantasy.ProviderMetadata{ Name: metadata, }, }) { @@ -894,14 +894,14 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St delete(ongoingToolCalls, done.OutputIndex) hasFunctionCall = true - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputEnd, ID: done.Item.CallID, }) { return } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolCall, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolCall, ID: done.Item.CallID, ToolCallName: done.Item.Name, ToolCallInput: done.Item.Arguments, @@ -911,8 +911,8 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St } case "message": - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextEnd, ID: done.Item.ID, }) { return @@ -921,10 +921,10 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St case "reasoning": state := activeReasoning[done.Item.ID] if state != nil { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningEnd, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningEnd, ID: done.Item.ID, - ProviderMetadata: ai.ProviderMetadata{ + ProviderMetadata: fantasy.ProviderMetadata{ Name: state.metadata, }, }) { @@ -938,8 +938,8 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St delta := event.AsResponseFunctionCallArgumentsDelta() tc := ongoingToolCalls[delta.OutputIndex] if tc != nil { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputDelta, ID: tc.toolCallID, Delta: delta.Delta, }) { @@ -949,8 +949,8 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St case "response.output_text.delta": textDelta := event.AsResponseOutputTextDelta() - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextDelta, ID: textDelta.ItemID, Delta: textDelta.Delta, }) { @@ -963,11 +963,11 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St if state != nil { state.metadata.Summary = append(state.metadata.Summary, "") activeReasoning[added.ItemID] = state - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningDelta, ID: added.ItemID, Delta: "\n", - ProviderMetadata: ai.ProviderMetadata{ + ProviderMetadata: fantasy.ProviderMetadata{ Name: state.metadata, }, }) { @@ -983,11 +983,11 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St state.metadata.Summary[textDelta.SummaryIndex] += textDelta.Delta } activeReasoning[textDelta.ItemID] = state - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningDelta, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningDelta, ID: textDelta.ItemID, Delta: textDelta.Delta, - ProviderMetadata: ai.ProviderMetadata{ + ProviderMetadata: fantasy.ProviderMetadata{ Name: state.metadata, }, }) { @@ -998,7 +998,7 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St case "response.completed", "response.incomplete": completed := event.AsResponseCompleted() finishReason = mapResponsesFinishReason(completed.Response.IncompleteDetails.Reason, hasFunctionCall) - usage = ai.Usage{ + usage = fantasy.Usage{ InputTokens: completed.Response.Usage.InputTokens, OutputTokens: completed.Response.Usage.OutputTokens, TotalTokens: completed.Response.Usage.InputTokens + completed.Response.Usage.OutputTokens, @@ -1012,8 +1012,8 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St case "error": errorEvent := event.AsError() - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeError, + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, Error: fmt.Errorf("response error: %s (code: %s)", errorEvent.Message, errorEvent.Code), }) { return @@ -1024,22 +1024,22 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St err := stream.Err() if err != nil { - yield(ai.StreamPart{ - Type: ai.StreamPartTypeError, + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, Error: o.handleError(err), }) return } - yield(ai.StreamPart{ - Type: ai.StreamPartTypeFinish, + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeFinish, Usage: usage, FinishReason: finishReason, }) }, nil } -func getReasoningMetadata(providerOptions ai.ProviderOptions) *ResponsesReasoningMetadata { +func getReasoningMetadata(providerOptions fantasy.ProviderOptions) *ResponsesReasoningMetadata { if openaiResponsesOptions, ok := providerOptions[Name]; ok { if reasoning, ok := openaiResponsesOptions.(*ResponsesReasoningMetadata); ok { return reasoning diff --git a/openai/responses_options.go b/openai/responses_options.go index bfa4abb490259c926f6870756af4c431e7bfb74c..34d81577a170011c05664d6a1b5eaef2d6efc2df 100644 --- a/openai/responses_options.go +++ b/openai/responses_options.go @@ -3,7 +3,7 @@ package openai import ( "slices" - "charm.land/fantasy/ai" + "charm.land/fantasy" ) type ResponsesReasoningMetadata struct { @@ -105,15 +105,15 @@ var responsesModelIds = append([]string{ func (*ResponsesProviderOptions) Options() {} -func NewResponsesProviderOptions(opts *ResponsesProviderOptions) ai.ProviderOptions { - return ai.ProviderOptions{ +func NewResponsesProviderOptions(opts *ResponsesProviderOptions) fantasy.ProviderOptions { + return fantasy.ProviderOptions{ Name: opts, } } func ParseResponsesOptions(data map[string]any) (*ResponsesProviderOptions, error) { var options ResponsesProviderOptions - if err := ai.ParseOptions(data, &options); err != nil { + if err := fantasy.ParseOptions(data, &options); err != nil { return nil, err } return &options, nil diff --git a/openaicompat/language_model_hooks.go b/openaicompat/language_model_hooks.go index 6b804861d43322530ff8817964448dc851e9c3ba..f2aa75fe75c0b52d91e3a08e722ab7e412a54efe 100644 --- a/openaicompat/language_model_hooks.go +++ b/openaicompat/language_model_hooks.go @@ -4,7 +4,7 @@ import ( "encoding/json" "fmt" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/openai" openaisdk "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/packages/param" @@ -13,12 +13,12 @@ import ( const reasoningStartedCtx = "reasoning_started" -func PrepareCallFunc(model ai.LanguageModel, params *openaisdk.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) { +func PrepareCallFunc(model fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) { providerOptions := &ProviderOptions{} if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, ai.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil) + return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil) } } @@ -43,15 +43,15 @@ func PrepareCallFunc(model ai.LanguageModel, params *openaisdk.ChatCompletionNew return nil, nil } -func ExtraContentFunc(choice openaisdk.ChatCompletionChoice) []ai.Content { - var content []ai.Content +func ExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content { + var content []fantasy.Content reasoningData := ReasoningData{} err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData) if err != nil { return content } if reasoningData.ReasoningContent != "" { - content = append(content, ai.ReasoningContent{ + content = append(content, fantasy.ReasoningContent{ Text: reasoningData.ReasoningContent, }) } @@ -70,7 +70,7 @@ func extractReasoningContext(ctx map[string]any) bool { return b } -func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool) { +func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) { if len(chunk.Choices) == 0 { return ctx, true } @@ -81,17 +81,17 @@ func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(ai.StreamPa reasoningData := ReasoningData{} err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData) if err != nil { - yield(ai.StreamPart{ - Type: ai.StreamPartTypeError, - Error: ai.NewAIError("Unexpected", "error unmarshalling delta", err), + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: fantasy.NewAIError("Unexpected", "error unmarshalling delta", err), }) return ctx, false } emitEvent := func(reasoningContent string) bool { if !reasoningStarted { - shouldContinue := yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningStart, + shouldContinue := yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningStart, ID: fmt.Sprintf("%d", inx), }) if !shouldContinue { @@ -99,8 +99,8 @@ func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(ai.StreamPa } } - return yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningDelta, + return yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningDelta, ID: fmt.Sprintf("%d", inx), Delta: reasoningContent, }) @@ -113,8 +113,8 @@ func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(ai.StreamPa } if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) { ctx[reasoningStartedCtx] = false - return ctx, yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningEnd, + return ctx, yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningEnd, ID: fmt.Sprintf("%d", inx), }) } diff --git a/openaicompat/openaicompat.go b/openaicompat/openaicompat.go index d73a01b11c1e08620cb32f636a23f55bf36f17aa..545865ca43da4aad3e8c5962b23d53173eeef486 100644 --- a/openaicompat/openaicompat.go +++ b/openaicompat/openaicompat.go @@ -1,7 +1,7 @@ package openaicompat import ( - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/openai" "github.com/openai/openai-go/v2/option" ) @@ -18,7 +18,7 @@ const ( type Option = func(*options) -func New(opts ...Option) ai.Provider { +func New(opts ...Option) fantasy.Provider { providerOptions := options{ openaiOptions: []openai.Option{ openai.WithName(Name), diff --git a/openaicompat/provider_options.go b/openaicompat/provider_options.go index 9db56f57f517e6d4a49caf6169bff9d807ae8c1b..71fe28c5a49d2366ff9054a6baf33cb88af3fb79 100644 --- a/openaicompat/provider_options.go +++ b/openaicompat/provider_options.go @@ -1,7 +1,7 @@ package openaicompat import ( - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/openai" ) @@ -16,15 +16,15 @@ type ReasoningData struct { func (*ProviderOptions) Options() {} -func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions { - return ai.ProviderOptions{ +func NewProviderOptions(opts *ProviderOptions) fantasy.ProviderOptions { + return fantasy.ProviderOptions{ Name: opts, } } func ParseOptions(data map[string]any) (*ProviderOptions, error) { var options ProviderOptions - if err := ai.ParseOptions(data, &options); err != nil { + if err := fantasy.ParseOptions(data, &options); err != nil { return nil, err } return &options, nil diff --git a/openrouter/language_model_hooks.go b/openrouter/language_model_hooks.go index de92b5d5eb56106abd39130bc4e74ecbfc25b79a..eb38cf9e05c8ece6b93e1ea831e4c128ac0014cc 100644 --- a/openrouter/language_model_hooks.go +++ b/openrouter/language_model_hooks.go @@ -5,7 +5,7 @@ import ( "fmt" "maps" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/anthropic" openaisdk "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/packages/param" @@ -13,12 +13,12 @@ import ( const reasoningStartedCtx = "reasoning_started" -func languagePrepareModelCall(model ai.LanguageModel, params *openaisdk.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) { +func languagePrepareModelCall(model fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) { providerOptions := &ProviderOptions{} if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, ai.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil) + return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil) } } @@ -67,18 +67,18 @@ func languagePrepareModelCall(model ai.LanguageModel, params *openaisdk.ChatComp return nil, nil } -func languageModelExtraContent(choice openaisdk.ChatCompletionChoice) []ai.Content { - var content []ai.Content +func languageModelExtraContent(choice openaisdk.ChatCompletionChoice) []fantasy.Content { + var content []fantasy.Content reasoningData := ReasoningData{} err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData) if err != nil { return content } for _, detail := range reasoningData.ReasoningDetails { - var metadata ai.ProviderMetadata + var metadata fantasy.ProviderMetadata if detail.Signature != "" { - metadata = ai.ProviderMetadata{ + metadata = fantasy.ProviderMetadata{ Name: &ReasoningMetadata{ Signature: detail.Signature, }, @@ -89,17 +89,17 @@ func languageModelExtraContent(choice openaisdk.ChatCompletionChoice) []ai.Conte } switch detail.Type { case "reasoning.text": - content = append(content, ai.ReasoningContent{ + content = append(content, fantasy.ReasoningContent{ Text: detail.Text, ProviderMetadata: metadata, }) case "reasoning.summary": - content = append(content, ai.ReasoningContent{ + content = append(content, fantasy.ReasoningContent{ Text: detail.Summary, ProviderMetadata: metadata, }) case "reasoning.encrypted": - content = append(content, ai.ReasoningContent{ + content = append(content, fantasy.ReasoningContent{ Text: "[REDACTED]", ProviderMetadata: metadata, }) @@ -120,7 +120,7 @@ func extractReasoningContext(ctx map[string]any) bool { return b } -func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool) { +func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) { if len(chunk.Choices) == 0 { return ctx, true } @@ -131,17 +131,17 @@ func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(ai reasoningData := ReasoningData{} err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData) if err != nil { - yield(ai.StreamPart{ - Type: ai.StreamPartTypeError, - Error: ai.NewAIError("Unexpected", "error unmarshalling delta", err), + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: fantasy.NewAIError("Unexpected", "error unmarshalling delta", err), }) return ctx, false } emitEvent := func(reasoningContent string, signature string) bool { if !reasoningStarted { - shouldContinue := yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningStart, + shouldContinue := yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningStart, ID: fmt.Sprintf("%d", inx), }) if !shouldContinue { @@ -149,10 +149,10 @@ func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(ai } } - var metadata ai.ProviderMetadata + var metadata fantasy.ProviderMetadata if signature != "" { - metadata = ai.ProviderMetadata{ + metadata = fantasy.ProviderMetadata{ Name: &ReasoningMetadata{ Signature: signature, }, @@ -162,8 +162,8 @@ func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(ai } } - return yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningDelta, + return yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningDelta, ID: fmt.Sprintf("%d", inx), Delta: reasoningContent, ProviderMetadata: metadata, @@ -188,8 +188,8 @@ func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(ai } if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) { ctx[reasoningStartedCtx] = false - return ctx, yield(ai.StreamPart{ - Type: ai.StreamPartTypeReasoningEnd, + return ctx, yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningEnd, ID: fmt.Sprintf("%d", inx), }) } @@ -197,9 +197,9 @@ func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(ai return ctx, true } -func languageModelUsage(response openaisdk.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) { +func languageModelUsage(response openaisdk.ChatCompletion) (fantasy.Usage, fantasy.ProviderOptionsData) { if len(response.Choices) == 0 { - return ai.Usage{}, nil + return fantasy.Usage{}, nil } openrouterUsage := UsageAccounting{} usage := response.Usage @@ -220,7 +220,7 @@ func languageModelUsage(response openaisdk.ChatCompletion) (ai.Usage, ai.Provide Usage: openrouterUsage, } - return ai.Usage{ + return fantasy.Usage{ InputTokens: usage.PromptTokens, OutputTokens: usage.CompletionTokens, TotalTokens: usage.TotalTokens, @@ -229,10 +229,10 @@ func languageModelUsage(response openaisdk.ChatCompletion) (ai.Usage, ai.Provide }, providerMetadata } -func languageModelStreamUsage(chunk openaisdk.ChatCompletionChunk, _ map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) { +func languageModelStreamUsage(chunk openaisdk.ChatCompletionChunk, _ map[string]any, metadata fantasy.ProviderMetadata) (fantasy.Usage, fantasy.ProviderMetadata) { usage := chunk.Usage if usage.TotalTokens == 0 { - return ai.Usage{}, nil + return fantasy.Usage{}, nil } streamProviderMetadata := &ProviderMetadata{} @@ -255,7 +255,7 @@ func languageModelStreamUsage(chunk openaisdk.ChatCompletionChunk, _ map[string] // we do this here because the acc does not add prompt details completionTokenDetails := usage.CompletionTokensDetails promptTokenDetails := usage.PromptTokensDetails - aiUsage := ai.Usage{ + aiUsage := fantasy.Usage{ InputTokens: usage.PromptTokens, OutputTokens: usage.CompletionTokens, TotalTokens: usage.TotalTokens, @@ -263,7 +263,7 @@ func languageModelStreamUsage(chunk openaisdk.ChatCompletionChunk, _ map[string] CacheReadTokens: promptTokenDetails.CachedTokens, } - return aiUsage, ai.ProviderMetadata{ + return aiUsage, fantasy.ProviderMetadata{ Name: streamProviderMetadata, } } diff --git a/openrouter/openrouter.go b/openrouter/openrouter.go index d9e800fbb01b40d45f117ffdb8a28b9696601675..c0a393746fb55b3ad08b0abc620c4ba744ffab57 100644 --- a/openrouter/openrouter.go +++ b/openrouter/openrouter.go @@ -3,7 +3,7 @@ package openrouter import ( "encoding/json" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/openai" "github.com/openai/openai-go/v2/option" ) @@ -20,7 +20,7 @@ const ( type Option = func(*options) -func New(opts ...Option) ai.Provider { +func New(opts ...Option) fantasy.Provider { providerOptions := options{ openaiOptions: []openai.Option{ openai.WithName(Name), diff --git a/openrouter/provider_options.go b/openrouter/provider_options.go index 0629a7314817971e6af9c8722e463f1ce471504a..e615ff861771f5450cb280f3a9137ddf9aefc88c 100644 --- a/openrouter/provider_options.go +++ b/openrouter/provider_options.go @@ -1,7 +1,7 @@ package openrouter import ( - "charm.land/fantasy/ai" + "charm.land/fantasy" ) type ReasoningEffort string @@ -116,15 +116,15 @@ func ReasoningEffortOption(e ReasoningEffort) *ReasoningEffort { return &e } -func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions { - return ai.ProviderOptions{ +func NewProviderOptions(opts *ProviderOptions) fantasy.ProviderOptions { + return fantasy.ProviderOptions{ Name: opts, } } func ParseOptions(data map[string]any) (*ProviderOptions, error) { var options ProviderOptions - if err := ai.ParseOptions(data, &options); err != nil { + if err := fantasy.ParseOptions(data, &options); err != nil { return nil, err } return &options, nil diff --git a/ai/provider.go b/provider.go similarity index 85% rename from ai/provider.go rename to provider.go index dd944873706802445e823096138c572042315869..11843f52be7ab09800f270662f079a9a820f09a5 100644 --- a/ai/provider.go +++ b/provider.go @@ -1,4 +1,4 @@ -package ai +package fantasy type Provider interface { Name() string diff --git a/providertests/anthropic_test.go b/providertests/anthropic_test.go index cd2f0f35abbdea430711d41d2442dec6d7d23ed2..09099b1187f679e81ee21f8b4aa396afade5931a 100644 --- a/providertests/anthropic_test.go +++ b/providertests/anthropic_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/anthropic" "github.com/stretchr/testify/require" "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder" @@ -24,7 +24,7 @@ func TestAnthropicCommon(t *testing.T) { } func TestAnthropicThinking(t *testing.T) { - opts := ai.ProviderOptions{ + opts := fantasy.ProviderOptions{ anthropic.Name: &anthropic.ProviderOptions{ Thinking: &anthropic.ThinkingProviderOption{ BudgetTokens: 4000, @@ -41,16 +41,16 @@ func TestAnthropicThinking(t *testing.T) { testThinking(t, pairs, testAnthropicThinking) } -func testAnthropicThinking(t *testing.T, result *ai.AgentResult) { +func testAnthropicThinking(t *testing.T, result *fantasy.AgentResult) { reasoningContentCount := 0 signaturesCount := 0 // Test if we got the signature for _, step := range result.Steps { for _, msg := range step.Messages { for _, content := range msg.Content { - if content.GetType() == ai.ContentTypeReasoning { + if content.GetType() == fantasy.ContentTypeReasoning { reasoningContentCount += 1 - reasoningContent, ok := ai.AsContentType[ai.ReasoningPart](content) + reasoningContent, ok := fantasy.AsContentType[fantasy.ReasoningPart](content) if !ok { continue } @@ -78,7 +78,7 @@ func testAnthropicThinking(t *testing.T, result *ai.AgentResult) { } func anthropicBuilder(model string) builderFunc { - return func(r *recorder.Recorder) (ai.LanguageModel, error) { + return func(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := anthropic.New( anthropic.WithAPIKey(os.Getenv("FANTASY_ANTHROPIC_API_KEY")), anthropic.WithHTTPClient(&http.Client{Transport: r}), diff --git a/providertests/azure_test.go b/providertests/azure_test.go index e6683f9829d562b351da2fcfad26796d1504a051..f3cb55ec4ec91fee4bb4c7b81a5baf04aed59bff 100644 --- a/providertests/azure_test.go +++ b/providertests/azure_test.go @@ -6,7 +6,7 @@ import ( "os" "testing" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/azure" "charm.land/fantasy/openai" "github.com/stretchr/testify/require" @@ -24,7 +24,7 @@ func TestAzureCommon(t *testing.T) { } func TestAzureThinking(t *testing.T) { - opts := ai.ProviderOptions{ + opts := fantasy.ProviderOptions{ openai.Name: &openai.ProviderOptions{ ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortLow), }, @@ -35,11 +35,11 @@ func TestAzureThinking(t *testing.T) { }, testAzureThinking) } -func testAzureThinking(t *testing.T, result *ai.AgentResult) { +func testAzureThinking(t *testing.T, result *fantasy.AgentResult) { require.Greater(t, result.Response.Usage.ReasoningTokens, int64(0), "expected reasoning tokens, got none") } -func builderAzureO4Mini(r *recorder.Recorder) (ai.LanguageModel, error) { +func builderAzureO4Mini(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := azure.New( azure.WithBaseURL(cmp.Or(os.Getenv("FANTASY_AZURE_BASE_URL"), defaultBaseURL)), azure.WithAPIKey(cmp.Or(os.Getenv("FANTASY_AZURE_API_KEY"), "(missing)")), @@ -48,7 +48,7 @@ func builderAzureO4Mini(r *recorder.Recorder) (ai.LanguageModel, error) { return provider.LanguageModel("o4-mini") } -func builderAzureGpt5Mini(r *recorder.Recorder) (ai.LanguageModel, error) { +func builderAzureGpt5Mini(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := azure.New( azure.WithBaseURL(cmp.Or(os.Getenv("FANTASY_AZURE_BASE_URL"), defaultBaseURL)), azure.WithAPIKey(cmp.Or(os.Getenv("FANTASY_AZURE_API_KEY"), "(missing)")), @@ -57,7 +57,7 @@ func builderAzureGpt5Mini(r *recorder.Recorder) (ai.LanguageModel, error) { return provider.LanguageModel("gpt-5-mini") } -func builderAzureGrok3Mini(r *recorder.Recorder) (ai.LanguageModel, error) { +func builderAzureGrok3Mini(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := azure.New( azure.WithBaseURL(cmp.Or(os.Getenv("FANTASY_AZURE_BASE_URL"), defaultBaseURL)), azure.WithAPIKey(cmp.Or(os.Getenv("FANTASY_AZURE_API_KEY"), "(missing)")), diff --git a/providertests/bedrock_test.go b/providertests/bedrock_test.go index 24cb211e9fa01fd01cfedd64afc5201f0c188a25..e70c6655cbbd240fe176f79dd6add4a33d6742a9 100644 --- a/providertests/bedrock_test.go +++ b/providertests/bedrock_test.go @@ -4,7 +4,7 @@ import ( "net/http" "testing" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/bedrock" "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder" ) @@ -17,7 +17,7 @@ func TestBedrockCommon(t *testing.T) { }) } -func builderBedrockClaude3Sonnet(r *recorder.Recorder) (ai.LanguageModel, error) { +func builderBedrockClaude3Sonnet(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := bedrock.New( bedrock.WithHTTPClient(&http.Client{Transport: r}), bedrock.WithSkipAuth(!r.IsRecording()), @@ -25,7 +25,7 @@ func builderBedrockClaude3Sonnet(r *recorder.Recorder) (ai.LanguageModel, error) return provider.LanguageModel("us.anthropic.claude-3-sonnet-20240229-v1:0") } -func builderBedrockClaude3Opus(r *recorder.Recorder) (ai.LanguageModel, error) { +func builderBedrockClaude3Opus(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := bedrock.New( bedrock.WithHTTPClient(&http.Client{Transport: r}), bedrock.WithSkipAuth(!r.IsRecording()), @@ -33,7 +33,7 @@ func builderBedrockClaude3Opus(r *recorder.Recorder) (ai.LanguageModel, error) { return provider.LanguageModel("us.anthropic.claude-3-opus-20240229-v1:0") } -func builderBedrockClaude3Haiku(r *recorder.Recorder) (ai.LanguageModel, error) { +func builderBedrockClaude3Haiku(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := bedrock.New( bedrock.WithHTTPClient(&http.Client{Transport: r}), bedrock.WithSkipAuth(!r.IsRecording()), diff --git a/providertests/common_test.go b/providertests/common_test.go index 3ae5e20c7544f83523678f2577b9dfeeb6e0cd85..b840882a2cfdccf1380b61ee9439d7aefd9ce268 100644 --- a/providertests/common_test.go +++ b/providertests/common_test.go @@ -7,7 +7,7 @@ import ( "strings" "testing" - "charm.land/fantasy/ai" + "charm.land/fantasy" "github.com/joho/godotenv" "github.com/stretchr/testify/require" "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder" @@ -27,12 +27,12 @@ type testModel struct { reasoning bool } -type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error) +type builderFunc func(r *recorder.Recorder) (fantasy.LanguageModel, error) type builderPair struct { name string builder builderFunc - providerOptions ai.ProviderOptions + providerOptions fantasy.ProviderOptions } func testCommon(t *testing.T, pairs []builderPair) { @@ -46,7 +46,7 @@ func testCommon(t *testing.T, pairs []builderPair) { } func testSimple(t *testing.T, pair builderPair) { - checkResult := func(t *testing.T, result *ai.AgentResult) { + checkResult := func(t *testing.T, result *fantasy.AgentResult) { options := []string{"Oi", "oi", "Olรก", "olรก"} got := result.Response.Content.Text() require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options) @@ -58,14 +58,14 @@ func testSimple(t *testing.T, pair builderPair) { languageModel, err := pair.builder(r) require.NoError(t, err, "failed to build language model") - agent := ai.NewAgent( + agent := fantasy.NewAgent( languageModel, - ai.WithSystemPrompt("You are a helpful assistant"), + fantasy.WithSystemPrompt("You are a helpful assistant"), ) - result, err := agent.Generate(t.Context(), ai.AgentCall{ + result, err := agent.Generate(t.Context(), fantasy.AgentCall{ Prompt: "Say hi in Portuguese", ProviderOptions: pair.providerOptions, - MaxOutputTokens: ai.Opt(int64(4000)), + MaxOutputTokens: fantasy.Opt(int64(4000)), }) require.NoError(t, err, "failed to generate") checkResult(t, result) @@ -76,14 +76,14 @@ func testSimple(t *testing.T, pair builderPair) { languageModel, err := pair.builder(r) require.NoError(t, err, "failed to build language model") - agent := ai.NewAgent( + agent := fantasy.NewAgent( languageModel, - ai.WithSystemPrompt("You are a helpful assistant"), + fantasy.WithSystemPrompt("You are a helpful assistant"), ) - result, err := agent.Stream(t.Context(), ai.AgentStreamCall{ + result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{ Prompt: "Say hi in Portuguese", ProviderOptions: pair.providerOptions, - MaxOutputTokens: ai.Opt(int64(4000)), + MaxOutputTokens: fantasy.Opt(int64(4000)), }) require.NoError(t, err, "failed to generate") checkResult(t, result) @@ -95,20 +95,20 @@ func testTool(t *testing.T, pair builderPair) { Location string `json:"location" description:"the city"` } - weatherTool := ai.NewAgentTool( + weatherTool := fantasy.NewAgentTool( "weather", "Get weather information for a location", - func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) { - return ai.NewTextResponse("40 C"), nil + func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.NewTextResponse("40 C"), nil }, ) - checkResult := func(t *testing.T, result *ai.AgentResult) { + checkResult := func(t *testing.T, result *fantasy.AgentResult) { require.GreaterOrEqual(t, len(result.Steps), 2) - var toolCalls []ai.ToolCallContent + var toolCalls []fantasy.ToolCallContent for _, content := range result.Steps[0].Content { - if content.GetType() == ai.ContentTypeToolCall { - toolCalls = append(toolCalls, content.(ai.ToolCallContent)) + if content.GetType() == fantasy.ContentTypeToolCall { + toolCalls = append(toolCalls, content.(fantasy.ToolCallContent)) } } for _, tc := range toolCalls { @@ -129,15 +129,15 @@ func testTool(t *testing.T, pair builderPair) { languageModel, err := pair.builder(r) require.NoError(t, err, "failed to build language model") - agent := ai.NewAgent( + agent := fantasy.NewAgent( languageModel, - ai.WithSystemPrompt("You are a helpful assistant"), - ai.WithTools(weatherTool), + fantasy.WithSystemPrompt("You are a helpful assistant"), + fantasy.WithTools(weatherTool), ) - result, err := agent.Generate(t.Context(), ai.AgentCall{ + result, err := agent.Generate(t.Context(), fantasy.AgentCall{ Prompt: "What's the weather in Florence,Italy?", ProviderOptions: pair.providerOptions, - MaxOutputTokens: ai.Opt(int64(4000)), + MaxOutputTokens: fantasy.Opt(int64(4000)), }) require.NoError(t, err, "failed to generate") checkResult(t, result) @@ -148,15 +148,15 @@ func testTool(t *testing.T, pair builderPair) { languageModel, err := pair.builder(r) require.NoError(t, err, "failed to build language model") - agent := ai.NewAgent( + agent := fantasy.NewAgent( languageModel, - ai.WithSystemPrompt("You are a helpful assistant"), - ai.WithTools(weatherTool), + fantasy.WithSystemPrompt("You are a helpful assistant"), + fantasy.WithTools(weatherTool), ) - result, err := agent.Stream(t.Context(), ai.AgentStreamCall{ + result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{ Prompt: "What's the weather in Florence,Italy?", ProviderOptions: pair.providerOptions, - MaxOutputTokens: ai.Opt(int64(4000)), + MaxOutputTokens: fantasy.Opt(int64(4000)), }) require.NoError(t, err, "failed to generate") checkResult(t, result) @@ -183,29 +183,29 @@ func testMultiTool(t *testing.T, pair builderPair) { B int `json:"b" description:"second number"` } - addTool := ai.NewAgentTool( + addTool := fantasy.NewAgentTool( "add", "Add two numbers", - func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) { + func(ctx context.Context, input CalculatorInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { result := input.A + input.B - return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil + return fantasy.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil }, ) - multiplyTool := ai.NewAgentTool( + multiplyTool := fantasy.NewAgentTool( "multiply", "Multiply two numbers", - func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) { + func(ctx context.Context, input CalculatorInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { result := input.A * input.B - return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil + return fantasy.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil }, ) - checkResult := func(t *testing.T, result *ai.AgentResult) { + checkResult := func(t *testing.T, result *fantasy.AgentResult) { require.Len(t, result.Steps, 2) - var toolCalls []ai.ToolCallContent + var toolCalls []fantasy.ToolCallContent for _, content := range result.Steps[0].Content { - if content.GetType() == ai.ContentTypeToolCall { - toolCalls = append(toolCalls, content.(ai.ToolCallContent)) + if content.GetType() == fantasy.ContentTypeToolCall { + toolCalls = append(toolCalls, content.(fantasy.ToolCallContent)) } } for _, tc := range toolCalls { @@ -224,16 +224,16 @@ func testMultiTool(t *testing.T, pair builderPair) { languageModel, err := pair.builder(r) require.NoError(t, err, "failed to build language model") - agent := ai.NewAgent( + agent := fantasy.NewAgent( languageModel, - ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."), - ai.WithTools(addTool), - ai.WithTools(multiplyTool), + fantasy.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."), + fantasy.WithTools(addTool), + fantasy.WithTools(multiplyTool), ) - result, err := agent.Generate(t.Context(), ai.AgentCall{ + result, err := agent.Generate(t.Context(), fantasy.AgentCall{ Prompt: "Add and multiply the number 2 and 3", ProviderOptions: pair.providerOptions, - MaxOutputTokens: ai.Opt(int64(4000)), + MaxOutputTokens: fantasy.Opt(int64(4000)), }) require.NoError(t, err, "failed to generate") checkResult(t, result) @@ -244,23 +244,23 @@ func testMultiTool(t *testing.T, pair builderPair) { languageModel, err := pair.builder(r) require.NoError(t, err, "failed to build language model") - agent := ai.NewAgent( + agent := fantasy.NewAgent( languageModel, - ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."), - ai.WithTools(addTool), - ai.WithTools(multiplyTool), + fantasy.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."), + fantasy.WithTools(addTool), + fantasy.WithTools(multiplyTool), ) - result, err := agent.Stream(t.Context(), ai.AgentStreamCall{ + result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{ Prompt: "Add and multiply the number 2 and 3", ProviderOptions: pair.providerOptions, - MaxOutputTokens: ai.Opt(int64(4000)), + MaxOutputTokens: fantasy.Opt(int64(4000)), }) require.NoError(t, err, "failed to generate") checkResult(t, result) }) } -func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) { +func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *fantasy.AgentResult)) { for _, pair := range pairs { t.Run(pair.name, func(t *testing.T) { t.Run("thinking", func(t *testing.T) { @@ -273,20 +273,20 @@ func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T Location string `json:"location" description:"the city"` } - weatherTool := ai.NewAgentTool( + weatherTool := fantasy.NewAgentTool( "weather", "Get weather information for a location", - func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) { - return ai.NewTextResponse("40 C"), nil + func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.NewTextResponse("40 C"), nil }, ) - agent := ai.NewAgent( + agent := fantasy.NewAgent( languageModel, - ai.WithSystemPrompt("You are a helpful assistant"), - ai.WithTools(weatherTool), + fantasy.WithSystemPrompt("You are a helpful assistant"), + fantasy.WithTools(weatherTool), ) - result, err := agent.Generate(t.Context(), ai.AgentCall{ + result, err := agent.Generate(t.Context(), fantasy.AgentCall{ Prompt: "What's the weather in Florence, Italy?", ProviderOptions: pair.providerOptions, }) @@ -309,20 +309,20 @@ func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T Location string `json:"location" description:"the city"` } - weatherTool := ai.NewAgentTool( + weatherTool := fantasy.NewAgentTool( "weather", "Get weather information for a location", - func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) { - return ai.NewTextResponse("40 C"), nil + func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.NewTextResponse("40 C"), nil }, ) - agent := ai.NewAgent( + agent := fantasy.NewAgent( languageModel, - ai.WithSystemPrompt("You are a helpful assistant"), - ai.WithTools(weatherTool), + fantasy.WithSystemPrompt("You are a helpful assistant"), + fantasy.WithTools(weatherTool), ) - result, err := agent.Stream(t.Context(), ai.AgentStreamCall{ + result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{ Prompt: "What's the weather in Florence, Italy?", ProviderOptions: pair.providerOptions, }) diff --git a/providertests/google_test.go b/providertests/google_test.go index 2dc873c08aaa8cba54c090e9fbaa53c9e01ae065..644820d2b42ab488fb3d660d87986adffb32a117 100644 --- a/providertests/google_test.go +++ b/providertests/google_test.go @@ -6,7 +6,7 @@ import ( "os" "testing" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/google" "github.com/stretchr/testify/require" "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder" @@ -35,11 +35,11 @@ func TestGoogleCommon(t *testing.T) { } func TestGoogleThinking(t *testing.T) { - opts := ai.ProviderOptions{ + opts := fantasy.ProviderOptions{ google.Name: &google.ProviderOptions{ ThinkingConfig: &google.ThinkingConfig{ - ThinkingBudget: ai.Opt(int64(100)), - IncludeThoughts: ai.Opt(true), + ThinkingBudget: fantasy.Opt(int64(100)), + IncludeThoughts: fantasy.Opt(true), }, }, } @@ -54,13 +54,13 @@ func TestGoogleThinking(t *testing.T) { testThinking(t, pairs, testGoogleThinking) } -func testGoogleThinking(t *testing.T, result *ai.AgentResult) { +func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) { reasoningContentCount := 0 // Test if we got the signature for _, step := range result.Steps { for _, msg := range step.Messages { for _, content := range msg.Content { - if content.GetType() == ai.ContentTypeReasoning { + if content.GetType() == fantasy.ContentTypeReasoning { reasoningContentCount += 1 } } @@ -70,7 +70,7 @@ func testGoogleThinking(t *testing.T, result *ai.AgentResult) { } func geminiBuilder(model string) builderFunc { - return func(r *recorder.Recorder) (ai.LanguageModel, error) { + return func(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := google.New( google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")), google.WithHTTPClient(&http.Client{Transport: r}), @@ -80,7 +80,7 @@ func geminiBuilder(model string) builderFunc { } func vertexBuilder(model string) builderFunc { - return func(r *recorder.Recorder) (ai.LanguageModel, error) { + return func(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := google.New( google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")), google.WithHTTPClient(&http.Client{Transport: r}), diff --git a/providertests/openai_responses_test.go b/providertests/openai_responses_test.go index beb98e9e2a64ceaac824af7a6744edfcf518c8f7..60af99c08499da8992449bcbefc3a6bfac1e4d52 100644 --- a/providertests/openai_responses_test.go +++ b/providertests/openai_responses_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/openai" "github.com/stretchr/testify/require" "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder" @@ -20,7 +20,7 @@ func TestOpenAIResponsesCommon(t *testing.T) { } func openAIReasoningBuilder(model string) builderFunc { - return func(r *recorder.Recorder) (ai.LanguageModel, error) { + return func(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := openai.New( openai.WithAPIKey(os.Getenv("FANTASY_OPENAI_API_KEY")), openai.WithHTTPClient(&http.Client{Transport: r}), @@ -31,13 +31,13 @@ func openAIReasoningBuilder(model string) builderFunc { } func TestOpenAIResponsesWithSummaryThinking(t *testing.T) { - opts := ai.ProviderOptions{ + opts := fantasy.ProviderOptions{ openai.Name: &openai.ResponsesProviderOptions{ Include: []openai.IncludeType{ openai.IncludeReasoningEncryptedContent, }, ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortHigh), - ReasoningSummary: ai.Opt("auto"), + ReasoningSummary: fantasy.Opt("auto"), }, } var pairs []builderPair @@ -50,16 +50,16 @@ func TestOpenAIResponsesWithSummaryThinking(t *testing.T) { testThinking(t, pairs, testOpenAIResponsesThinkingWithSummaryThinking) } -func testOpenAIResponsesThinkingWithSummaryThinking(t *testing.T, result *ai.AgentResult) { +func testOpenAIResponsesThinkingWithSummaryThinking(t *testing.T, result *fantasy.AgentResult) { reasoningContentCount := 0 encryptedData := 0 // Test if we got the signature for _, step := range result.Steps { for _, msg := range step.Messages { for _, content := range msg.Content { - if content.GetType() == ai.ContentTypeReasoning { + if content.GetType() == fantasy.ContentTypeReasoning { reasoningContentCount += 1 - reasoningContent, ok := ai.AsContentType[ai.ReasoningPart](content) + reasoningContent, ok := fantasy.AsContentType[fantasy.ReasoningPart](content) if !ok { continue } diff --git a/providertests/openai_test.go b/providertests/openai_test.go index f9a0ef2a03580d2a79a74c6a42edaf64d1286da1..15e8240fc84668b1eae80c0954da1d3e354bf483 100644 --- a/providertests/openai_test.go +++ b/providertests/openai_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/openai" "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder" ) @@ -26,7 +26,7 @@ func TestOpenAICommon(t *testing.T) { } func openAIBuilder(model string) builderFunc { - return func(r *recorder.Recorder) (ai.LanguageModel, error) { + return func(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := openai.New( openai.WithAPIKey(os.Getenv("FANTASY_OPENAI_API_KEY")), openai.WithHTTPClient(&http.Client{Transport: r}), diff --git a/providertests/openaicompat_test.go b/providertests/openaicompat_test.go index f3b15ea74d5dbbf97a12bda3c588617a29b5da66..14a9e2004a70075bcef1a144ac8773ad179ee69c 100644 --- a/providertests/openaicompat_test.go +++ b/providertests/openaicompat_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/openai" "charm.land/fantasy/openaicompat" "github.com/stretchr/testify/require" @@ -22,7 +22,7 @@ func TestOpenAICompatibleCommon(t *testing.T) { } func TestOpenAICompatibleThinking(t *testing.T) { - opts := ai.ProviderOptions{ + opts := fantasy.ProviderOptions{ openaicompat.Name: &openaicompat.ProviderOptions{ ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortHigh), }, @@ -33,12 +33,12 @@ func TestOpenAICompatibleThinking(t *testing.T) { }, testOpenAICompatThinking) } -func testOpenAICompatThinking(t *testing.T, result *ai.AgentResult) { +func testOpenAICompatThinking(t *testing.T, result *fantasy.AgentResult) { reasoningContentCount := 0 for _, step := range result.Steps { for _, msg := range step.Messages { for _, content := range msg.Content { - if content.GetType() == ai.ContentTypeReasoning { + if content.GetType() == fantasy.ContentTypeReasoning { reasoningContentCount += 1 } } @@ -47,7 +47,7 @@ func testOpenAICompatThinking(t *testing.T, result *ai.AgentResult) { require.Greater(t, reasoningContentCount, 0, "expected reasoning content, got none") } -func builderXAIGrokCodeFast(r *recorder.Recorder) (ai.LanguageModel, error) { +func builderXAIGrokCodeFast(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := openaicompat.New( openaicompat.WithBaseURL("https://api.x.ai/v1"), openaicompat.WithAPIKey(os.Getenv("FANTASY_XAI_API_KEY")), @@ -56,7 +56,7 @@ func builderXAIGrokCodeFast(r *recorder.Recorder) (ai.LanguageModel, error) { return provider.LanguageModel("grok-code-fast-1") } -func builderXAIGrok4Fast(r *recorder.Recorder) (ai.LanguageModel, error) { +func builderXAIGrok4Fast(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := openaicompat.New( openaicompat.WithBaseURL("https://api.x.ai/v1"), openaicompat.WithAPIKey(os.Getenv("FANTASY_XAI_API_KEY")), @@ -65,7 +65,7 @@ func builderXAIGrok4Fast(r *recorder.Recorder) (ai.LanguageModel, error) { return provider.LanguageModel("grok-4-fast") } -func builderXAIGrok3Mini(r *recorder.Recorder) (ai.LanguageModel, error) { +func builderXAIGrok3Mini(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := openaicompat.New( openaicompat.WithBaseURL("https://api.x.ai/v1"), openaicompat.WithAPIKey(os.Getenv("FANTASY_XAI_API_KEY")), @@ -74,7 +74,7 @@ func builderXAIGrok3Mini(r *recorder.Recorder) (ai.LanguageModel, error) { return provider.LanguageModel("grok-3-mini") } -func builderZAIGLM45(r *recorder.Recorder) (ai.LanguageModel, error) { +func builderZAIGLM45(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := openaicompat.New( openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"), openaicompat.WithAPIKey(os.Getenv("FANTASY_ZAI_API_KEY")), @@ -83,7 +83,7 @@ func builderZAIGLM45(r *recorder.Recorder) (ai.LanguageModel, error) { return provider.LanguageModel("glm-4.5") } -func builderGroq(r *recorder.Recorder) (ai.LanguageModel, error) { +func builderGroq(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := openaicompat.New( openaicompat.WithBaseURL("https://api.groq.com/openai/v1"), openaicompat.WithAPIKey(os.Getenv("FANTASY_GROQ_API_KEY")), diff --git a/providertests/openrouter_test.go b/providertests/openrouter_test.go index 244afcb34da402b78e91ef34d45d2c9b69d9483c..d9001e18930d57a5dc95c595252b39fc1e6d19b8 100644 --- a/providertests/openrouter_test.go +++ b/providertests/openrouter_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "charm.land/fantasy/ai" + "charm.land/fantasy" "charm.land/fantasy/openrouter" "github.com/stretchr/testify/require" "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder" @@ -31,7 +31,7 @@ func TestOpenRouterCommon(t *testing.T) { } func TestOpenRouterThinking(t *testing.T) { - opts := ai.ProviderOptions{ + opts := fantasy.ProviderOptions{ openrouter.Name: &openrouter.ProviderOptions{ Reasoning: &openrouter.ReasoningOptions{ Effort: openrouter.ReasoningEffortOption(openrouter.ReasoningEffortMedium), @@ -54,16 +54,16 @@ func TestOpenRouterThinking(t *testing.T) { }, testOpenrouterThinkingWithSignature) } -func testOpenrouterThinkingWithSignature(t *testing.T, result *ai.AgentResult) { +func testOpenrouterThinkingWithSignature(t *testing.T, result *fantasy.AgentResult) { reasoningContentCount := 0 signaturesCount := 0 // Test if we got the signature for _, step := range result.Steps { for _, msg := range step.Messages { for _, content := range msg.Content { - if content.GetType() == ai.ContentTypeReasoning { + if content.GetType() == fantasy.ContentTypeReasoning { reasoningContentCount += 1 - reasoningContent, ok := ai.AsContentType[ai.ReasoningPart](content) + reasoningContent, ok := fantasy.AsContentType[fantasy.ReasoningPart](content) if !ok { continue } @@ -92,12 +92,12 @@ func testOpenrouterThinkingWithSignature(t *testing.T, result *ai.AgentResult) { testAnthropicThinking(t, result) } -func testOpenrouterThinking(t *testing.T, result *ai.AgentResult) { +func testOpenrouterThinking(t *testing.T, result *fantasy.AgentResult) { reasoningContentCount := 0 for _, step := range result.Steps { for _, msg := range step.Messages { for _, content := range msg.Content { - if content.GetType() == ai.ContentTypeReasoning { + if content.GetType() == fantasy.ContentTypeReasoning { reasoningContentCount += 1 } } @@ -107,7 +107,7 @@ func testOpenrouterThinking(t *testing.T, result *ai.AgentResult) { } func openrouterBuilder(model string) builderFunc { - return func(r *recorder.Recorder) (ai.LanguageModel, error) { + return func(r *recorder.Recorder) (fantasy.LanguageModel, error) { provider := openrouter.New( openrouter.WithAPIKey(os.Getenv("FANTASY_OPENROUTER_API_KEY")), openrouter.WithHTTPClient(&http.Client{Transport: r}), diff --git a/ai/retry.go b/retry.go similarity index 99% rename from ai/retry.go rename to retry.go index f96389d58a1e88fd4a91705578e6bfe49df13753..9201b99cf82951c487d4d0bcb0f694b2f89b4772 100644 --- a/ai/retry.go +++ b/retry.go @@ -1,4 +1,4 @@ -package ai +package fantasy import ( "context" diff --git a/ai/tool.go b/tool.go similarity index 99% rename from ai/tool.go rename to tool.go index 46e178ec45ea3b1cf630f24734346e52f7f6a80d..568608ed0e83225603eb7089a7c0c9c24752f49a 100644 --- a/ai/tool.go +++ b/tool.go @@ -1,4 +1,4 @@ -package ai +package fantasy import ( "context" diff --git a/ai/tool_test.go b/tool_test.go similarity index 99% rename from ai/tool_test.go rename to tool_test.go index 8539c38f1e9f1c38b529c9db62988c6fd005f144..ed47228bb159b4555e0c4f282f3339a491ac38db 100644 --- a/ai/tool_test.go +++ b/tool_test.go @@ -1,4 +1,4 @@ -package ai +package fantasy import ( "context" diff --git a/ai/util.go b/util.go similarity index 95% rename from ai/util.go rename to util.go index 2d11a0c10bd4e8c5ad3806073a51e924196ad439..5f8a4a7cbc9c5fd68853a59de2c2acd6e245ae87 100644 --- a/ai/util.go +++ b/util.go @@ -1,4 +1,4 @@ -package ai +package fantasy import "github.com/go-viper/mapstructure/v2"