diff --git a/internal/config/config.go b/internal/config/config.go index 2de7afbc106c48214716e1338236768e77ed2e97..fcca3d0b3aded2b35d1ed069d3ed379faf4b59c3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -773,42 +773,48 @@ func (c *Config) Resolver() VariableResolver { } func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { - testURL := "" - headers := make(map[string]string) - apiKey, _ := resolver.ResolveValue(c.APIKey) + var ( + providerID = catwalk.InferenceProvider(c.ID) + testURL = "" + headers = make(map[string]string) + apiKey, _ = resolver.ResolveValue(c.APIKey) + ) + switch c.Type { case catwalk.TypeOpenAI, catwalk.TypeOpenAICompat, catwalk.TypeOpenRouter: baseURL, _ := resolver.ResolveValue(c.BaseURL) - if baseURL == "" { - baseURL = "https://api.openai.com/v1" - } - if c.ID == string(catwalk.InferenceProviderOpenRouter) { + baseURL = cmp.Or(baseURL, "https://api.openai.com/v1") + + switch providerID { + case catwalk.InferenceProviderOpenRouter: testURL = baseURL + "/credits" - } else { + default: testURL = baseURL + "/models" } + headers["Authorization"] = "Bearer " + apiKey case catwalk.TypeAnthropic: baseURL, _ := resolver.ResolveValue(c.BaseURL) - if baseURL == "" { - baseURL = "https://api.anthropic.com/v1" - } - testURL = baseURL + "/models" - // TODO: replace with const when catwalk is released - if c.ID == "kimi-coding" { + baseURL = cmp.Or(baseURL, "https://api.anthropic.com/v1") + + switch providerID { + case catwalk.InferenceKimiCoding: testURL = baseURL + "/v1/models" + default: + testURL = baseURL + "/models" } + headers["x-api-key"] = apiKey headers["anthropic-version"] = "2023-06-01" case catwalk.TypeGoogle: baseURL, _ := resolver.ResolveValue(c.BaseURL) - if baseURL == "" { - baseURL = "https://generativelanguage.googleapis.com" - } + baseURL = cmp.Or(baseURL, "https://generativelanguage.googleapis.com") testURL = baseURL + "/v1beta/models?key=" + url.QueryEscape(apiKey) } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() + client := &http.Client{} req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil) if err != nil { @@ -820,17 +826,19 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { for k, v := range c.ExtraHeaders { req.Header.Set(k, v) } + resp, err := client.Do(req) if err != nil { return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err) } defer resp.Body.Close() - if c.ID == string(catwalk.InferenceProviderZAI) { + + switch providerID { + case catwalk.InferenceProviderZAI: if resp.StatusCode == http.StatusUnauthorized { - // For z.ai just check if the http response is not 401. return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status) } - } else { + default: if resp.StatusCode != http.StatusOK { return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status) }