diff --git a/internal/config/config.go b/internal/config/config.go index f4467f71a86d027298c8adc5dabc872a02710d0c..1c20188a12a3955fde6b6eeed9f12ea39288e328 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -78,6 +78,8 @@ type ProviderConfig struct { // Extra headers to send with each request to the provider. ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + // Extra body + ExtraBody map[string]any `json:"extra_body,omitempty"` // Used to pass extra parameters to the provider. ExtraParams map[string]string `json:"-"` diff --git a/internal/config/load.go b/internal/config/load.go index f481be240e9d82520cef6c9f75210d8cbd1a0776..4e877ec15752d12fdbaf35bac241ca6f57415ead 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -171,6 +171,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know Type: p.Type, Disable: config.Disable, ExtraHeaders: config.ExtraHeaders, + ExtraBody: config.ExtraBody, ExtraParams: make(map[string]string), Models: p.Models, } diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index a65f0b752367ca7b2e62f9dd263a7dd6e5ce7a53..e8c980aeb3cc35a51071227093fac8e8cd3a4b0c 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -45,6 +45,12 @@ func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropi if useBedrock { anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) } + for _, header := range opts.extraHeaders { + anthropicClientOptions = append(anthropicClientOptions, option.WithHeaderAdd(header, opts.extraHeaders[header])) + } + for key, value := range opts.extraBody { + anthropicClientOptions = append(anthropicClientOptions, option.WithJSONSet(key, value)) + } return anthropic.NewClient(anthropicClientOptions...) } diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 4afac2c70809d6c98e0aa35022c296c3d95ef05e..f55914520774e2fcf5e6283e22365f4ce3621dc1 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -44,10 +44,12 @@ func createOpenAIClient(opts providerClientOptions) openai.Client { } } - if opts.extraHeaders != nil { - for key, value := range opts.extraHeaders { - openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value)) - } + for key, value := range opts.extraHeaders { + openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value)) + } + + for extraKey, extraValue := range opts.extraBody { + openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue)) } return openai.NewClient(openaiClientOptions...) diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 12dd09392942b0c00e7caa975deefffa994b47b8..412093334169b4c0d59fdd4f3f72b1e427651307 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -70,6 +70,7 @@ type providerClientOptions struct { systemMessage string maxTokens int64 extraHeaders map[string]string + extraBody map[string]any extraParams map[string]string } @@ -147,6 +148,7 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi config: cfg, apiKey: resolvedAPIKey, extraHeaders: cfg.ExtraHeaders, + extraBody: cfg.ExtraBody, model: func(tp config.SelectedModelType) provider.Model { return *config.Get().GetModelByType(tp) },