1package models
2
import (
	"cmp"
	"encoding/json"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strings"
	"time"
	"unicode"

	"github.com/opencode-ai/opencode/internal/logging"
	"github.com/spf13/viper"
)
16
17const (
18 ProviderLocal ModelProvider = "local"
19
20 localModelsPath = "v1/models"
21 lmStudioBetaModelsPath = "api/v0/models"
22)
23
24func init() {
25 if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" {
26 localEndpoint, err := url.Parse(endpoint)
27 if err != nil {
28 logging.Debug("Failed to parse local endpoint",
29 "error", err,
30 "endpoint", endpoint,
31 )
32 return
33 }
34
35 load := func(url *url.URL, path string) []localModel {
36 url.Path = path
37 return listLocalModels(url.String())
38 }
39
40 models := load(localEndpoint, lmStudioBetaModelsPath)
41
42 if len(models) == 0 {
43 models = load(localEndpoint, localModelsPath)
44 }
45
46 if len(models) == 0 {
47 logging.Debug("No local models found",
48 "endpoint", endpoint,
49 )
50 return
51 }
52
53 loadLocalModels(models)
54
55 viper.SetDefault("providers.local.apiKey", "dummy")
56 ProviderPopularity[ProviderLocal] = 0
57 }
58}
59
// Types that mirror the JSON returned by the model-listing endpoints.
type (
	// localModelList is the response envelope: the models are carried in
	// its "data" array.
	localModelList struct {
		Data []localModel `json:"data"`
	}

	// localModel is one entry of the "data" array. Within this file only
	// ID, Object, Type, State and LoadedContextLength are read; the other
	// fields are decoded but not otherwise used here.
	localModel struct {
		ID                  string `json:"id"`
		Object              string `json:"object"`
		Type                string `json:"type"`
		Publisher           string `json:"publisher"`
		Arch                string `json:"arch"`
		CompatibilityType   string `json:"compatibility_type"`
		Quantization        string `json:"quantization"`
		State               string `json:"state"`
		MaxContextLength    int64  `json:"max_context_length"`
		LoadedContextLength int64  `json:"loaded_context_length"`
	}
)
76
77func listLocalModels(modelsEndpoint string) []localModel {
78 res, err := http.Get(modelsEndpoint)
79 if err != nil {
80 logging.Debug("Failed to list local models",
81 "error", err,
82 "endpoint", modelsEndpoint,
83 )
84 return []localModel{}
85 }
86 defer res.Body.Close()
87
88 if res.StatusCode != http.StatusOK {
89 logging.Debug("Failed to list local models",
90 "status", res.StatusCode,
91 "endpoint", modelsEndpoint,
92 )
93 return []localModel{}
94 }
95
96 var modelList localModelList
97 if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil {
98 logging.Debug("Failed to list local models",
99 "error", err,
100 "endpoint", modelsEndpoint,
101 )
102 return []localModel{}
103 }
104
105 var supportedModels []localModel
106 for _, model := range modelList.Data {
107 if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) {
108 if model.Object != "model" || model.Type != "llm" {
109 logging.Debug("Skipping unsupported LMStudio model",
110 "endpoint", modelsEndpoint,
111 "id", model.ID,
112 "object", model.Object,
113 "type", model.Type,
114 )
115
116 continue
117 }
118 }
119
120 supportedModels = append(supportedModels, model)
121 }
122
123 return supportedModels
124}
125
126func loadLocalModels(models []localModel) {
127 for i, m := range models {
128 model := convertLocalModel(m)
129 SupportedModels[model.ID] = model
130
131 if i == 0 || m.State == "loaded" {
132 viper.SetDefault("agents.coder.model", model.ID)
133 viper.SetDefault("agents.summarizer.model", model.ID)
134 viper.SetDefault("agents.task.model", model.ID)
135 viper.SetDefault("agents.title.model", model.ID)
136 }
137 }
138}
139
140func convertLocalModel(model localModel) Model {
141 return Model{
142 ID: ModelID("local." + model.ID),
143 Name: friendlyModelName(model.ID),
144 Provider: ProviderLocal,
145 APIModel: model.ID,
146 ContextWindow: cmp.Or(model.LoadedContextLength, 4096),
147 DefaultMaxTokens: cmp.Or(model.LoadedContextLength, 4096),
148 CanReason: true,
149 SupportsAttachments: true,
150 }
151}
152
// modelInfoRegex splits a model name, case-insensitively, into up to three
// captured parts: a leading alphanumeric family, an optional version (digits
// and dots, optionally prefixed with "r" or "v") and an optional alphabetic
// label. The version and label may each be preceded by "-" or "_"; anything
// after them is matched but not captured.
var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`)

// friendlyModelName derives a display name from a model ID.
//
// The ID is read as "[publisher/]name[@tag]": the text after the first "@" is
// kept verbatim as a tag, and everything up to the last "/" of the remaining
// part is dropped. The name is then split with modelInfoRegex into family,
// version and label, which are rendered as capitalized family, upper-cased
// version and capitalized label, followed by the tag, joined by single spaces.
// Empty parts are omitted. If the name does not match modelInfoRegex, modelID
// is returned unchanged.
func friendlyModelName(modelID string) string {
	// Separate the tag before stripping the publisher, so that the publisher
	// prefix is removed whether or not a tag is present.
	mainID, tag, _ := strings.Cut(modelID, "@")

	if slash := strings.LastIndex(mainID, "/"); slash != -1 {
		mainID = mainID[slash+1:]
	}

	match := modelInfoRegex.FindStringSubmatch(mainID)
	if match == nil {
		return modelID
	}

	// capitalize upper-cases the first rune of s and leaves the rest as is.
	capitalize := func(s string) string {
		if s == "" {
			return ""
		}
		runes := []rune(s)
		runes[0] = unicode.ToUpper(runes[0])
		return string(runes)
	}

	// A successful match always has one entry per capture group; groups that
	// did not participate are empty strings, so no length checks are needed.
	candidates := [...]string{
		capitalize(match[1]),
		strings.ToUpper(match[2]),
		capitalize(match[3]),
		tag,
	}

	parts := make([]string, 0, len(candidates))
	for _, part := range candidates {
		if part != "" {
			parts = append(parts, part)
		}
	}

	return strings.Join(parts, " ")
}