refactor: cleanup the code a bit

Andrey Nering created

Change summary

internal/config/config.go | 48 +++++++++++++++++++++++-----------------
1 file changed, 28 insertions(+), 20 deletions(-)

Detailed changes

internal/config/config.go 🔗

@@ -773,42 +773,48 @@ func (c *Config) Resolver() VariableResolver {
 }
 
 func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
-	testURL := ""
-	headers := make(map[string]string)
-	apiKey, _ := resolver.ResolveValue(c.APIKey)
+	var (
+		providerID = catwalk.InferenceProvider(c.ID)
+		testURL    = ""
+		headers    = make(map[string]string)
+		apiKey, _  = resolver.ResolveValue(c.APIKey)
+	)
+
 	switch c.Type {
 	case catwalk.TypeOpenAI, catwalk.TypeOpenAICompat, catwalk.TypeOpenRouter:
 		baseURL, _ := resolver.ResolveValue(c.BaseURL)
-		if baseURL == "" {
-			baseURL = "https://api.openai.com/v1"
-		}
-		if c.ID == string(catwalk.InferenceProviderOpenRouter) {
+		baseURL = cmp.Or(baseURL, "https://api.openai.com/v1")
+
+		switch providerID {
+		case catwalk.InferenceProviderOpenRouter:
 			testURL = baseURL + "/credits"
-		} else {
+		default:
 			testURL = baseURL + "/models"
 		}
+
 		headers["Authorization"] = "Bearer " + apiKey
 	case catwalk.TypeAnthropic:
 		baseURL, _ := resolver.ResolveValue(c.BaseURL)
-		if baseURL == "" {
-			baseURL = "https://api.anthropic.com/v1"
-		}
-		testURL = baseURL + "/models"
-		// TODO: replace with const when catwalk is released
-		if c.ID == "kimi-coding" {
+		baseURL = cmp.Or(baseURL, "https://api.anthropic.com/v1")
+
+		switch providerID {
+		case catwalk.InferenceKimiCoding:
 			testURL = baseURL + "/v1/models"
+		default:
+			testURL = baseURL + "/models"
 		}
+
 		headers["x-api-key"] = apiKey
 		headers["anthropic-version"] = "2023-06-01"
 	case catwalk.TypeGoogle:
 		baseURL, _ := resolver.ResolveValue(c.BaseURL)
-		if baseURL == "" {
-			baseURL = "https://generativelanguage.googleapis.com"
-		}
+		baseURL = cmp.Or(baseURL, "https://generativelanguage.googleapis.com")
 		testURL = baseURL + "/v1beta/models?key=" + url.QueryEscape(apiKey)
 	}
+
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
+
 	client := &http.Client{}
 	req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
 	if err != nil {
@@ -820,17 +826,19 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
 	for k, v := range c.ExtraHeaders {
 		req.Header.Set(k, v)
 	}
+
 	resp, err := client.Do(req)
 	if err != nil {
 		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
 	}
 	defer resp.Body.Close()
-	if c.ID == string(catwalk.InferenceProviderZAI) {
+
+	switch providerID {
+	case catwalk.InferenceProviderZAI:
 		if resp.StatusCode == http.StatusUnauthorized {
-			// For z.ai just check if the http response is not 401.
 			return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status)
 		}
-	} else {
+	default:
 		if resp.StatusCode != http.StatusOK {
 			return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status)
 		}