diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index aa005198420ce3b87695d17d43a5b5c6ef7bd66d..04a49eebe2aeb110cd0cd55421d9b632480e7461 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -725,7 +725,7 @@ func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[str return azure.New(opts...) } -func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) { +func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) { var opts []bedrock.Option if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() @@ -734,9 +734,13 @@ func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.P if len(headers) > 0 { opts = append(opts, bedrock.WithHeaders(headers)) } - bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK") - if bearerToken != "" { - opts = append(opts, bedrock.WithAPIKey(bearerToken)) + switch { + case apiKey != "": + opts = append(opts, bedrock.WithAPIKey(apiKey)) + case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "": + opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK"))) + default: + // Skip, let the SDK do authentication. } return bedrock.New(opts...) } @@ -824,7 +828,7 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con case azure.Name: return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams) case bedrock.Name: - return c.buildBedrockProvider(headers) + return c.buildBedrockProvider(apiKey, headers) case google.Name: return c.buildGoogleProvider(baseURL, apiKey, headers) case "google-vertex": diff --git a/internal/config/config.go b/internal/config/config.go index 8e9b3f0fb7349f4b911c9a6c41fc3e3890f3f19e..340002968924538e543d2489a11a0c6232f897d6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,7 @@ package config import ( "cmp" "context" + "errors" "fmt" "log/slog" "maps" @@ -584,6 +585,14 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { baseURL, _ := resolver.ResolveValue(c.BaseURL) baseURL = cmp.Or(baseURL, "https://generativelanguage.googleapis.com") testURL = baseURL + "/v1beta/models?key=" + url.QueryEscape(apiKey) + case catwalk.TypeBedrock: + // NOTE: Bedrock has a `/foundation-models` endpoint that we could in + // theory use, but apparently the authorization is region-specific, + // so it's not so trivial. + if strings.HasPrefix(apiKey, "ABSK") { // Bedrock API keys + return nil + } + return errors.New("not a valid bedrock api key") } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)