local.go

  1package models
  2
import (
	"cmp"
	"encoding/json"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strings"
	"time"
	"unicode"

	"github.com/opencode-ai/opencode/internal/logging"
	"github.com/spf13/viper"
)
 16
 17const (
 18	ProviderLocal ModelProvider = "local"
 19
 20	localModelsPath        = "v1/models"
 21	lmStudioBetaModelsPath = "api/v0/models"
 22)
 23
 24func init() {
 25	if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" {
 26		localEndpoint, err := url.Parse(endpoint)
 27		if err != nil {
 28			logging.Debug("Failed to parse local endpoint",
 29				"error", err,
 30				"endpoint", endpoint,
 31			)
 32			return
 33		}
 34
 35		load := func(url *url.URL, path string) []localModel {
 36			url.Path = path
 37			return listLocalModels(url.String())
 38		}
 39
 40		models := load(localEndpoint, lmStudioBetaModelsPath)
 41
 42		if len(models) == 0 {
 43			models = load(localEndpoint, localModelsPath)
 44		}
 45
 46		if len(models) == 0 {
 47			logging.Debug("No local models found",
 48				"endpoint", endpoint,
 49			)
 50			return
 51		}
 52
 53		loadLocalModels(models)
 54
 55		viper.SetDefault("providers.local.apiKey", "dummy")
 56		ProviderPopularity[ProviderLocal] = 0
 57	}
 58}
 59
 60type localModelList struct {
 61	Data []localModel `json:"data"`
 62}
 63
 64type localModel struct {
 65	ID                  string `json:"id"`
 66	Object              string `json:"object"`
 67	Type                string `json:"type"`
 68	Publisher           string `json:"publisher"`
 69	Arch                string `json:"arch"`
 70	CompatibilityType   string `json:"compatibility_type"`
 71	Quantization        string `json:"quantization"`
 72	State               string `json:"state"`
 73	MaxContextLength    int64  `json:"max_context_length"`
 74	LoadedContextLength int64  `json:"loaded_context_length"`
 75}
 76
 77func listLocalModels(modelsEndpoint string) []localModel {
 78	res, err := http.Get(modelsEndpoint)
 79	if err != nil {
 80		logging.Debug("Failed to list local models",
 81			"error", err,
 82			"endpoint", modelsEndpoint,
 83		)
 84		return []localModel{}
 85	}
 86	defer res.Body.Close()
 87
 88	if res.StatusCode != http.StatusOK {
 89		logging.Debug("Failed to list local models",
 90			"status", res.StatusCode,
 91			"endpoint", modelsEndpoint,
 92		)
 93		return []localModel{}
 94	}
 95
 96	var modelList localModelList
 97	if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil {
 98		logging.Debug("Failed to list local models",
 99			"error", err,
100			"endpoint", modelsEndpoint,
101		)
102		return []localModel{}
103	}
104
105	var supportedModels []localModel
106	for _, model := range modelList.Data {
107		if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) {
108			if model.Object != "model" || model.Type != "llm" {
109				logging.Debug("Skipping unsupported LMStudio model",
110					"endpoint", modelsEndpoint,
111					"id", model.ID,
112					"object", model.Object,
113					"type", model.Type,
114				)
115
116				continue
117			}
118		}
119
120		supportedModels = append(supportedModels, model)
121	}
122
123	return supportedModels
124}
125
126func loadLocalModels(models []localModel) {
127	for i, m := range models {
128		model := convertLocalModel(m)
129		SupportedModels[model.ID] = model
130
131		if i == 0 || m.State == "loaded" {
132			viper.SetDefault("agents.coder.model", model.ID)
133			viper.SetDefault("agents.summarizer.model", model.ID)
134			viper.SetDefault("agents.task.model", model.ID)
135			viper.SetDefault("agents.title.model", model.ID)
136		}
137	}
138}
139
140func convertLocalModel(model localModel) Model {
141	return Model{
142		ID:                  ModelID("local." + model.ID),
143		Name:                friendlyModelName(model.ID),
144		Provider:            ProviderLocal,
145		APIModel:            model.ID,
146		ContextWindow:       cmp.Or(model.LoadedContextLength, 4096),
147		DefaultMaxTokens:    cmp.Or(model.LoadedContextLength, 4096),
148		CanReason:           true,
149		SupportsAttachments: true,
150	}
151}
152
// modelInfoRegex splits a bare model name into up to three parts:
//   - group 1: the family, a run of ASCII letters/digits ("llama", "gemma");
//   - group 2: an optional version, an optional r/v followed by a digit and
//     further digits or dots ("3.1", "v2", "r1");
//   - group 3: an optional alphabetic label ("instruct", "distill").
//
// Each optional part may be preceded by one "-" or "_". Matching is
// case-insensitive and trailing text after the matched parts is ignored.
//
// NOTE(review): the family group also accepts digits and is greedy, so
// "qwen2.5-7b" yields family "qwen2" and no version.
var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`)

// friendlyModelName derives a display name from a model ID of the form
// "[publisher/]name[@tag]". For example,
// "lmstudio-community/llama-3.1-8b@q4_k_m" becomes "Llama 3.1 q4_k_m".
// The publisher prefix is dropped, the family and label are capitalized,
// the version is upper-cased and the tag is appended verbatim. If the name
// does not match modelInfoRegex, modelID is returned unchanged.
func friendlyModelName(modelID string) string {
	// Split off the tag first, then drop the publisher prefix from what is
	// left, so neither step reintroduces the part the other removed.
	mainID, tag, _ := strings.Cut(modelID, "@")
	if slash := strings.LastIndex(mainID, "/"); slash != -1 {
		mainID = mainID[slash+1:]
	}

	match := modelInfoRegex.FindStringSubmatch(mainID)
	if match == nil {
		return modelID
	}

	capitalize := func(s string) string {
		if s == "" {
			return ""
		}
		runes := []rune(s)
		runes[0] = unicode.ToUpper(runes[0])
		return string(runes)
	}

	// The regex has exactly three groups, so match always has four elements;
	// optional groups that did not participate are empty strings.
	candidates := []string{
		capitalize(match[1]),
		strings.ToUpper(match[2]),
		capitalize(match[3]),
		tag,
	}

	parts := make([]string, 0, len(candidates))
	for _, p := range candidates {
		if p != "" {
			parts = append(parts, p)
		}
	}

	return strings.Join(parts, " ")
}