1package models
2
import (
	"cmp"
	"encoding/json"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strings"
	"time"
	"unicode"

	"github.com/opencode-ai/opencode/internal/logging"
	"github.com/spf13/viper"
)
16
17const (
18 ProviderLocal ModelProvider = "local"
19
20 localModelsPath = "v1/models"
21 lmStudioBetaModelsPath = "api/v0/models"
22)
23
24func init() {
25 if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" {
26 localEndpoint, err := url.Parse(endpoint)
27 if err != nil {
28 logging.Debug("Failed to parse local endpoint",
29 "error", err,
30 "endpoint", endpoint,
31 )
32 return
33 }
34
35 load := func(url *url.URL, path string) []localModel {
36 url.Path = path
37 return listLocalModels(url.String())
38 }
39
40 models := load(localEndpoint, lmStudioBetaModelsPath)
41
42 if len(models) == 0 {
43 models = load(localEndpoint, localModelsPath)
44 }
45
46 if len(models) == 0 {
47 logging.Debug("No local models found",
48 "endpoint", endpoint,
49 )
50 return
51 }
52
53 loadLocalModels(models)
54
55 viper.SetDefault("providers.local.apiKey", "dummy")
56 ProviderPopularity[ProviderLocal] = 0
57 }
58}
59
// Response shapes for a local endpoint's model listing.
type (
	// localModelList is the listing envelope; entries live under "data".
	localModelList struct {
		Data []localModel `json:"data"`
	}

	// localModel is a single listing entry. In this file, Object and Type
	// drive the LM Studio filter, State picks the default agent model, and
	// LoadedContextLength sizes the context window. Fields absent from a
	// response decode to their zero values.
	localModel struct {
		ID                  string `json:"id"`
		Object              string `json:"object"`
		Type                string `json:"type"`
		Publisher           string `json:"publisher"`
		Arch                string `json:"arch"`
		CompatibilityType   string `json:"compatibility_type"`
		Quantization        string `json:"quantization"`
		State               string `json:"state"`
		MaxContextLength    int64  `json:"max_context_length"`
		LoadedContextLength int64  `json:"loaded_context_length"`
	}
)
76
77func listLocalModels(modelsEndpoint string) []localModel {
78 res, err := http.Get(modelsEndpoint)
79 if err != nil {
80 logging.Debug("Failed to list local models",
81 "error", err,
82 "endpoint", modelsEndpoint,
83 )
84 }
85 defer res.Body.Close()
86
87 if res.StatusCode != http.StatusOK {
88 logging.Debug("Failed to list local models",
89 "status", res.StatusCode,
90 "endpoint", modelsEndpoint,
91 )
92 }
93
94 var modelList localModelList
95 if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil {
96 logging.Debug("Failed to list local models",
97 "error", err,
98 "endpoint", modelsEndpoint,
99 )
100 }
101
102 var supportedModels []localModel
103 for _, model := range modelList.Data {
104 if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) {
105 if model.Object != "model" || model.Type != "llm" {
106 logging.Debug("Skipping unsupported LMStudio model",
107 "endpoint", modelsEndpoint,
108 "id", model.ID,
109 "object", model.Object,
110 "type", model.Type,
111 )
112
113 continue
114 }
115 }
116
117 supportedModels = append(supportedModels, model)
118 }
119
120 return supportedModels
121}
122
123func loadLocalModels(models []localModel) {
124 for i, m := range models {
125 model := convertLocalModel(m)
126 SupportedModels[model.ID] = model
127
128 if i == 0 || m.State == "loaded" {
129 viper.SetDefault("agents.coder.model", model.ID)
130 viper.SetDefault("agents.summarizer.model", model.ID)
131 viper.SetDefault("agents.task.model", model.ID)
132 viper.SetDefault("agents.title.model", model.ID)
133 }
134 }
135}
136
137func convertLocalModel(model localModel) Model {
138 return Model{
139 ID: ModelID("local." + model.ID),
140 Name: friendlyModelName(model.ID),
141 Provider: ProviderLocal,
142 APIModel: model.ID,
143 ContextWindow: cmp.Or(model.LoadedContextLength, 4096),
144 DefaultMaxTokens: cmp.Or(model.LoadedContextLength, 4096),
145 CanReason: true,
146 SupportsAttachments: true,
147 }
148}
149
// modelInfoRegex splits a bare model ID (publisher prefix and @tag already
// removed) into family, optional version and optional label, e.g.
// "llama-3" -> ("llama", "3", ""). Matching is case-insensitive and anything
// after the label is ignored.
var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`)

// friendlyModelName derives a display name from a local model ID of the form
// "publisher/family-version-label@tag".
//
// The publisher prefix (through the last '/') is dropped. The tag (after the
// first '@') is appended verbatim. The family and label are capitalized and
// the version is upper-cased, so "meta/llama-3@q8" becomes "Llama 3 q8".
// IDs that the pattern cannot parse are returned unchanged.
func friendlyModelName(modelID string) string {
	// Split off the tag before stripping the publisher. Slicing the original
	// ID at '@' afterwards would bring the publisher prefix back.
	mainID, tag, _ := strings.Cut(modelID, "@")
	if slash := strings.LastIndex(mainID, "/"); slash != -1 {
		mainID = mainID[slash+1:]
	}

	match := modelInfoRegex.FindStringSubmatch(mainID)
	if match == nil {
		return modelID
	}

	capitalize := func(s string) string {
		if s == "" {
			return ""
		}
		runes := []rune(s)
		runes[0] = unicode.ToUpper(runes[0])
		return string(runes)
	}

	// On success match has an entry for every capture group; groups that
	// did not participate are "".
	family := capitalize(match[1])
	version := strings.ToUpper(match[2])
	label := capitalize(match[3])

	parts := make([]string, 0, 4)
	for _, part := range []string{family, version, label, tag} {
		if part != "" {
			parts = append(parts, part)
		}
	}

	return strings.Join(parts, " ")
}