1package ollama
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "net/http"
8 "time"
9
10 "github.com/charmbracelet/crush/internal/fur/provider"
11)
12
const (
	// defaultOllamaURL is the base URL of a locally running Ollama server
	// (Ollama's standard listen address).
	defaultOllamaURL = "http://localhost:11434"
	// requestTimeout bounds every HTTP request made to the Ollama API so
	// callers never block long when the server is absent.
	requestTimeout = 2 * time.Second
)
17
18// IsRunning checks if Ollama is running by attempting to connect to its API
19func IsRunning(ctx context.Context) bool {
20 client := &http.Client{
21 Timeout: requestTimeout,
22 }
23
24 req, err := http.NewRequestWithContext(ctx, "GET", defaultOllamaURL+"/api/tags", nil)
25 if err != nil {
26 return false
27 }
28
29 resp, err := client.Do(req)
30 if err != nil {
31 return false
32 }
33 defer resp.Body.Close()
34
35 return resp.StatusCode == http.StatusOK
36}
37
38// GetModels retrieves available models from Ollama
39func GetModels(ctx context.Context) ([]provider.Model, error) {
40 client := &http.Client{
41 Timeout: requestTimeout,
42 }
43
44 req, err := http.NewRequestWithContext(ctx, "GET", defaultOllamaURL+"/api/tags", nil)
45 if err != nil {
46 return nil, fmt.Errorf("failed to create request: %w", err)
47 }
48
49 resp, err := client.Do(req)
50 if err != nil {
51 return nil, fmt.Errorf("failed to connect to Ollama: %w", err)
52 }
53 defer resp.Body.Close()
54
55 if resp.StatusCode != http.StatusOK {
56 return nil, fmt.Errorf("Ollama returned status %d", resp.StatusCode)
57 }
58
59 var tagsResponse OllamaTagsResponse
60 if err := json.NewDecoder(resp.Body).Decode(&tagsResponse); err != nil {
61 return nil, fmt.Errorf("failed to decode response: %w", err)
62 }
63
64 models := make([]provider.Model, len(tagsResponse.Models))
65 for i, ollamaModel := range tagsResponse.Models {
66 models[i] = provider.Model{
67 ID: ollamaModel.Name,
68 Model: ollamaModel.Name,
69 CostPer1MIn: 0, // Local models have no cost
70 CostPer1MOut: 0,
71 CostPer1MInCached: 0,
72 CostPer1MOutCached: 0,
73 ContextWindow: getContextWindow(ollamaModel.Details.Family),
74 DefaultMaxTokens: 4096,
75 CanReason: false,
76 HasReasoningEffort: false,
77 SupportsImages: supportsImages(ollamaModel.Details.Family),
78 }
79 }
80
81 return models, nil
82}
83
84// GetRunningModels returns models that are currently loaded in memory
85func GetRunningModels(ctx context.Context) ([]OllamaRunningModel, error) {
86 client := &http.Client{
87 Timeout: requestTimeout,
88 }
89
90 req, err := http.NewRequestWithContext(ctx, "GET", defaultOllamaURL+"/api/ps", nil)
91 if err != nil {
92 return nil, fmt.Errorf("failed to create request: %w", err)
93 }
94
95 resp, err := client.Do(req)
96 if err != nil {
97 return nil, fmt.Errorf("failed to connect to Ollama: %w", err)
98 }
99 defer resp.Body.Close()
100
101 if resp.StatusCode != http.StatusOK {
102 return nil, fmt.Errorf("Ollama returned status %d", resp.StatusCode)
103 }
104
105 var psResponse OllamaRunningModelsResponse
106 if err := json.NewDecoder(resp.Body).Decode(&psResponse); err != nil {
107 return nil, fmt.Errorf("failed to decode response: %w", err)
108 }
109
110 return psResponse.Models, nil
111}
112
113// IsModelLoaded checks if a specific model is currently loaded in memory
114func IsModelLoaded(ctx context.Context, modelName string) (bool, error) {
115 runningModels, err := GetRunningModels(ctx)
116 if err != nil {
117 return false, err
118 }
119
120 for _, model := range runningModels {
121 if model.Name == modelName {
122 return true, nil
123 }
124 }
125
126 return false, nil
127}
128
129// GetProvider returns a provider.Provider for Ollama if it's running
130func GetProvider(ctx context.Context) (*provider.Provider, error) {
131 if !IsRunning(ctx) {
132 return nil, fmt.Errorf("Ollama is not running")
133 }
134
135 models, err := GetModels(ctx)
136 if err != nil {
137 return nil, fmt.Errorf("failed to get models: %w", err)
138 }
139
140 return &provider.Provider{
141 Name: "Ollama",
142 ID: "ollama",
143 Models: models,
144 }, nil
145}
146
147// getContextWindow returns an estimated context window based on model family
148func getContextWindow(family string) int64 {
149 switch family {
150 case "llama":
151 return 131072 // Llama 3.x context window
152 case "mistral":
153 return 32768
154 case "gemma":
155 return 8192
156 case "qwen", "qwen2":
157 return 131072
158 case "phi":
159 return 131072
160 case "codellama":
161 return 16384
162 default:
163 return 8192 // Conservative default
164 }
165}
166
167// supportsImages returns whether a model family supports image inputs
168func supportsImages(family string) bool {
169 switch family {
170 case "llama-vision", "llava":
171 return true
172 default:
173 return false
174 }
175}