diff --git a/openai/language_model.go b/openai/language_model.go
index 188d81ee82320480167c8bf5de3e3c35140d636b..bc9c42bd1e9c94c36095865fc6cb0b0948006b97 100644
--- a/openai/language_model.go
+++ b/openai/language_model.go
@@ -87,12 +87,12 @@ func newLanguageModel(modelID string, provider string, client openai.Client, opt
 		modelID:                    modelID,
 		provider:                   provider,
 		client:                     client,
-		generateIDFunc:             defaultGenerateID,
-		prepareCallFunc:            defaultPrepareLanguageModelCall,
-		mapFinishReasonFunc:        defaultMapFinishReason,
-		usageFunc:                  defaultUsage,
-		streamUsageFunc:            defaultStreamUsage,
-		streamProviderMetadataFunc: defaultStreamProviderMetadataFunc,
+		generateIDFunc:             DefaultGenerateID,
+		prepareCallFunc:            DefaultPrepareCallFunc,
+		mapFinishReasonFunc:        DefaultMapFinishReasonFunc,
+		usageFunc:                  DefaultUsageFunc,
+		streamUsageFunc:            DefaultStreamUsageFunc,
+		streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
 	}
 
 	for _, o := range opts {
@@ -303,7 +303,7 @@ func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
 	return &ai.Response{
 		Content:      content,
 		Usage:        usage,
-		FinishReason: defaultMapFinishReason(choice),
+		FinishReason: DefaultMapFinishReasonFunc(choice),
 		ProviderMetadata: ai.ProviderMetadata{
 			Name: providerMetadata,
 		},
diff --git a/openai/language_model_hooks.go b/openai/language_model_hooks.go
index ceb2a57c71afb9bcecbf93acb66b87aa38fcdb63..2a47615080cc0801d4973e688983945147827aa8 100644
--- a/openai/language_model_hooks.go
+++ b/openai/language_model_hooks.go
@@ -21,11 +21,11 @@ type (
 	LanguageModelStreamProviderMetadataFunc = func(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata
 )
 
-func defaultGenerateID() string {
+func DefaultGenerateID() string {
 	return uuid.NewString()
 }
 
-func defaultPrepareLanguageModelCall(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
+func DefaultPrepareCallFunc(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
 	if call.ProviderOptions == nil {
 		return nil, nil
 	}
@@ -162,7 +162,7 @@ func defaultPrepareLanguageModelCall(model ai.LanguageModel, params *openai.Chat
 
 	return warnings, nil
 }
 
-func defaultMapFinishReason(choice openai.ChatCompletionChoice) ai.FinishReason {
+func DefaultMapFinishReasonFunc(choice openai.ChatCompletionChoice) ai.FinishReason {
 	finishReason := choice.FinishReason
 	switch finishReason {
 	case "stop":
@@ -178,7 +178,7 @@ func defaultMapFinishReason(choice openai.ChatCompletionChoice) ai.FinishReason
 	}
 }
 
-func defaultUsage(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
+func DefaultUsageFunc(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
 	if len(response.Choices) == 0 {
 		return ai.Usage{}, nil
 	}
@@ -211,7 +211,7 @@ func defaultUsage(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsD
 	}, providerMetadata
 }
 
-func defaultStreamUsage(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
+func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
 	if chunk.Usage.TotalTokens == 0 {
 		return ai.Usage{}, nil
 	}
@@ -250,7 +250,7 @@ func defaultStreamUsage(chunk openai.ChatCompletionChunk, ctx map[string]any, me
 	}
 }
 
-func defaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata {
+func DefaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata {
 	streamProviderMetadata, ok := metadata[Name]
 	if !ok {
 		streamProviderMetadata = &ProviderMetadata{}