client.go

  1package ollama
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"net/http"
  8	"time"
  9
 10	"github.com/charmbracelet/crush/internal/fur/provider"
 11)
 12
 13const (
 14	defaultOllamaURL = "http://localhost:11434"
 15	requestTimeout   = 2 * time.Second
 16)
 17
 18// IsRunning checks if Ollama is running by attempting to connect to its API
 19func IsRunning(ctx context.Context) bool {
 20	client := &http.Client{
 21		Timeout: requestTimeout,
 22	}
 23
 24	req, err := http.NewRequestWithContext(ctx, "GET", defaultOllamaURL+"/api/tags", nil)
 25	if err != nil {
 26		return false
 27	}
 28
 29	resp, err := client.Do(req)
 30	if err != nil {
 31		return false
 32	}
 33	defer resp.Body.Close()
 34
 35	return resp.StatusCode == http.StatusOK
 36}
 37
 38// GetModels retrieves available models from Ollama
 39func GetModels(ctx context.Context) ([]provider.Model, error) {
 40	client := &http.Client{
 41		Timeout: requestTimeout,
 42	}
 43
 44	req, err := http.NewRequestWithContext(ctx, "GET", defaultOllamaURL+"/api/tags", nil)
 45	if err != nil {
 46		return nil, fmt.Errorf("failed to create request: %w", err)
 47	}
 48
 49	resp, err := client.Do(req)
 50	if err != nil {
 51		return nil, fmt.Errorf("failed to connect to Ollama: %w", err)
 52	}
 53	defer resp.Body.Close()
 54
 55	if resp.StatusCode != http.StatusOK {
 56		return nil, fmt.Errorf("Ollama returned status %d", resp.StatusCode)
 57	}
 58
 59	var tagsResponse OllamaTagsResponse
 60	if err := json.NewDecoder(resp.Body).Decode(&tagsResponse); err != nil {
 61		return nil, fmt.Errorf("failed to decode response: %w", err)
 62	}
 63
 64	models := make([]provider.Model, len(tagsResponse.Models))
 65	for i, ollamaModel := range tagsResponse.Models {
 66		models[i] = provider.Model{
 67			ID:                 ollamaModel.Name,
 68			Model:              ollamaModel.Name,
 69			CostPer1MIn:        0, // Local models have no cost
 70			CostPer1MOut:       0,
 71			CostPer1MInCached:  0,
 72			CostPer1MOutCached: 0,
 73			ContextWindow:      getContextWindow(ollamaModel.Details.Family),
 74			DefaultMaxTokens:   4096,
 75			CanReason:          false,
 76			HasReasoningEffort: false,
 77			SupportsImages:     supportsImages(ollamaModel.Details.Family),
 78		}
 79	}
 80
 81	return models, nil
 82}
 83
 84// GetRunningModels returns models that are currently loaded in memory
 85func GetRunningModels(ctx context.Context) ([]OllamaRunningModel, error) {
 86	client := &http.Client{
 87		Timeout: requestTimeout,
 88	}
 89
 90	req, err := http.NewRequestWithContext(ctx, "GET", defaultOllamaURL+"/api/ps", nil)
 91	if err != nil {
 92		return nil, fmt.Errorf("failed to create request: %w", err)
 93	}
 94
 95	resp, err := client.Do(req)
 96	if err != nil {
 97		return nil, fmt.Errorf("failed to connect to Ollama: %w", err)
 98	}
 99	defer resp.Body.Close()
100
101	if resp.StatusCode != http.StatusOK {
102		return nil, fmt.Errorf("Ollama returned status %d", resp.StatusCode)
103	}
104
105	var psResponse OllamaRunningModelsResponse
106	if err := json.NewDecoder(resp.Body).Decode(&psResponse); err != nil {
107		return nil, fmt.Errorf("failed to decode response: %w", err)
108	}
109
110	return psResponse.Models, nil
111}
112
113// IsModelLoaded checks if a specific model is currently loaded in memory
114func IsModelLoaded(ctx context.Context, modelName string) (bool, error) {
115	runningModels, err := GetRunningModels(ctx)
116	if err != nil {
117		return false, err
118	}
119
120	for _, model := range runningModels {
121		if model.Name == modelName {
122			return true, nil
123		}
124	}
125
126	return false, nil
127}
128
129// GetProvider returns a provider.Provider for Ollama if it's running
130func GetProvider(ctx context.Context) (*provider.Provider, error) {
131	if !IsRunning(ctx) {
132		return nil, fmt.Errorf("Ollama is not running")
133	}
134
135	models, err := GetModels(ctx)
136	if err != nil {
137		return nil, fmt.Errorf("failed to get models: %w", err)
138	}
139
140	return &provider.Provider{
141		Name:   "Ollama",
142		ID:     "ollama",
143		Models: models,
144	}, nil
145}
146
// getContextWindow maps an Ollama model family name to an estimated context
// window size in tokens.
//
// Matching is exact and case-sensitive. Unknown families fall back to a
// conservative 8192.
// NOTE(review): these are hard-coded per-family estimates, not values read
// from the model; confirm them against the actual models in use.
func getContextWindow(family string) int64 {
	const conservativeDefault int64 = 8192

	switch family {
	case "llama", "qwen", "qwen2", "phi":
		// "llama" was annotated as the Llama 3.x context window.
		return 131072
	case "mistral":
		return 32768
	case "codellama":
		return 16384
	case "gemma":
		return 8192
	}
	return conservativeDefault
}
166
// supportsImages reports whether a model family is treated as accepting image
// inputs. Only the exact, case-sensitive family names "llama-vision" and
// "llava" qualify.
// NOTE(review): confirm these strings match the family values Ollama actually
// reports in /api/tags model details.
func supportsImages(family string) bool {
	return family == "llama-vision" || family == "llava"
}