@@ -773,42 +773,48 @@ func (c *Config) Resolver() VariableResolver {
}
func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
- testURL := ""
- headers := make(map[string]string)
- apiKey, _ := resolver.ResolveValue(c.APIKey)
+ var (
+ providerID = catwalk.InferenceProvider(c.ID)
+ testURL = ""
+ headers = make(map[string]string)
+ apiKey, _ = resolver.ResolveValue(c.APIKey)
+ )
+
switch c.Type {
case catwalk.TypeOpenAI, catwalk.TypeOpenAICompat, catwalk.TypeOpenRouter:
baseURL, _ := resolver.ResolveValue(c.BaseURL)
- if baseURL == "" {
- baseURL = "https://api.openai.com/v1"
- }
- if c.ID == string(catwalk.InferenceProviderOpenRouter) {
+ baseURL = cmp.Or(baseURL, "https://api.openai.com/v1")
+
+ switch providerID {
+ case catwalk.InferenceProviderOpenRouter:
testURL = baseURL + "/credits"
- } else {
+ default:
testURL = baseURL + "/models"
}
+
headers["Authorization"] = "Bearer " + apiKey
case catwalk.TypeAnthropic:
baseURL, _ := resolver.ResolveValue(c.BaseURL)
- if baseURL == "" {
- baseURL = "https://api.anthropic.com/v1"
- }
- testURL = baseURL + "/models"
- // TODO: replace with const when catwalk is released
- if c.ID == "kimi-coding" {
+ baseURL = cmp.Or(baseURL, "https://api.anthropic.com/v1")
+
+ switch providerID {
+ case catwalk.InferenceKimiCoding:
testURL = baseURL + "/v1/models"
+ default:
+ testURL = baseURL + "/models"
}
+
headers["x-api-key"] = apiKey
headers["anthropic-version"] = "2023-06-01"
case catwalk.TypeGoogle:
baseURL, _ := resolver.ResolveValue(c.BaseURL)
- if baseURL == "" {
- baseURL = "https://generativelanguage.googleapis.com"
- }
+ baseURL = cmp.Or(baseURL, "https://generativelanguage.googleapis.com")
testURL = baseURL + "/v1beta/models?key=" + url.QueryEscape(apiKey)
}
+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
+
client := &http.Client{}
req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
if err != nil {
@@ -820,17 +826,19 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
for k, v := range c.ExtraHeaders {
req.Header.Set(k, v)
}
+
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
}
defer resp.Body.Close()
- if c.ID == string(catwalk.InferenceProviderZAI) {
+
+ switch providerID {
+ case catwalk.InferenceProviderZAI:
if resp.StatusCode == http.StatusUnauthorized {
- // For z.ai just check if the http response is not 401.
return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status)
}
- } else {
+ default:
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status)
}