diff --git a/internal/config/config.go b/internal/config/config.go index d63a34f73d5210c2542be8a598ef38cb06339bd9..1d1f38ac303fe14c221cd0e186ca493ac76e067a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -75,6 +75,8 @@ type ProviderConfig struct { // Extra headers to send with each request to the provider. ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + // Extra body + ExtraBody map[string]any `json:"extra_body,omitempty"` // Used to pass extra parameters to the provider. ExtraParams map[string]string `json:"-"` diff --git a/internal/config/load.go b/internal/config/load.go index e056847aeeda476e819439384a16e0e237b067e1..0dd0a5fda518a7ef34b0da42dc6b23d37d4430ca 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -141,6 +141,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know Type: p.Type, Disable: config.Disable, ExtraHeaders: config.ExtraHeaders, + ExtraBody: config.ExtraBody, ExtraParams: make(map[string]string), Models: p.Models, } diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index a65f0b752367ca7b2e62f9dd263a7dd6e5ce7a53..e8c980aeb3cc35a51071227093fac8e8cd3a4b0c 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -45,6 +45,12 @@ func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropi if useBedrock { anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) } + for _, header := range opts.extraHeaders { + anthropicClientOptions = append(anthropicClientOptions, option.WithHeaderAdd(header, opts.extraHeaders[header])) + } + for key, value := range opts.extraBody { + anthropicClientOptions = append(anthropicClientOptions, option.WithJSONSet(key, value)) + } return anthropic.NewClient(anthropicClientOptions...) } diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 4afac2c70809d6c98e0aa35022c296c3d95ef05e..f55914520774e2fcf5e6283e22365f4ce3621dc1 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -44,10 +44,12 @@ func createOpenAIClient(opts providerClientOptions) openai.Client { } } - if opts.extraHeaders != nil { - for key, value := range opts.extraHeaders { - openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value)) - } + for key, value := range opts.extraHeaders { + openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value)) + } + + for extraKey, extraValue := range opts.extraBody { + openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue)) } return openai.NewClient(openaiClientOptions...) diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 12dd09392942b0c00e7caa975deefffa994b47b8..412093334169b4c0d59fdd4f3f72b1e427651307 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -70,6 +70,7 @@ type providerClientOptions struct { systemMessage string maxTokens int64 extraHeaders map[string]string + extraBody map[string]any extraParams map[string]string } @@ -147,6 +148,7 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi config: cfg, apiKey: resolvedAPIKey, extraHeaders: cfg.ExtraHeaders, + extraBody: cfg.ExtraBody, model: func(tp config.SelectedModelType) provider.Model { return *config.Get().GetModelByType(tp) },