local.go

  1package models
  2
import (
	"cmp"
	"encoding/json"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strings"
	"time"
	"unicode"

	"github.com/opencode-ai/opencode/internal/logging"
	"github.com/spf13/viper"
)
 16
 17const (
 18	ProviderLocal ModelProvider = "local"
 19
 20	localModelsPath        = "v1/models"
 21	lmStudioBetaModelsPath = "api/v0/models"
 22)
 23
 24func init() {
 25	if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" {
 26		localEndpoint, err := url.Parse(endpoint)
 27		if err != nil {
 28			logging.Debug("Failed to parse local endpoint",
 29				"error", err,
 30				"endpoint", endpoint,
 31			)
 32			return
 33		}
 34
 35		load := func(url *url.URL, path string) []localModel {
 36			url.Path = path
 37			return listLocalModels(url.String())
 38		}
 39
 40		models := load(localEndpoint, lmStudioBetaModelsPath)
 41
 42		if len(models) == 0 {
 43			models = load(localEndpoint, localModelsPath)
 44		}
 45
 46		if len(models) == 0 {
 47			logging.Debug("No local models found",
 48				"endpoint", endpoint,
 49			)
 50			return
 51		}
 52
 53		loadLocalModels(models)
 54
 55		viper.SetDefault("providers.local.apiKey", "dummy")
 56		ProviderPopularity[ProviderLocal] = 0
 57	}
 58}
 59
 60type localModelList struct {
 61	Data []localModel `json:"data"`
 62}
 63
 64type localModel struct {
 65	ID                  string `json:"id"`
 66	Object              string `json:"object"`
 67	Type                string `json:"type"`
 68	Publisher           string `json:"publisher"`
 69	Arch                string `json:"arch"`
 70	CompatibilityType   string `json:"compatibility_type"`
 71	Quantization        string `json:"quantization"`
 72	State               string `json:"state"`
 73	MaxContextLength    int64  `json:"max_context_length"`
 74	LoadedContextLength int64  `json:"loaded_context_length"`
 75}
 76
 77func listLocalModels(modelsEndpoint string) []localModel {
 78	res, err := http.Get(modelsEndpoint)
 79	if err != nil {
 80		logging.Debug("Failed to list local models",
 81			"error", err,
 82			"endpoint", modelsEndpoint,
 83		)
 84	}
 85	defer res.Body.Close()
 86
 87	if res.StatusCode != http.StatusOK {
 88		logging.Debug("Failed to list local models",
 89			"status", res.StatusCode,
 90			"endpoint", modelsEndpoint,
 91		)
 92	}
 93
 94	var modelList localModelList
 95	if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil {
 96		logging.Debug("Failed to list local models",
 97			"error", err,
 98			"endpoint", modelsEndpoint,
 99		)
100	}
101
102	var supportedModels []localModel
103	for _, model := range modelList.Data {
104		if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) {
105			if model.Object != "model" || model.Type != "llm" {
106				logging.Debug("Skipping unsupported LMStudio model",
107					"endpoint", modelsEndpoint,
108					"id", model.ID,
109					"object", model.Object,
110					"type", model.Type,
111				)
112
113				continue
114			}
115		}
116
117		supportedModels = append(supportedModels, model)
118	}
119
120	return supportedModels
121}
122
123func loadLocalModels(models []localModel) {
124	for i, m := range models {
125		model := convertLocalModel(m)
126		SupportedModels[model.ID] = model
127
128		if i == 0 || m.State == "loaded" {
129			viper.SetDefault("agents.coder.model", model.ID)
130			viper.SetDefault("agents.summarizer.model", model.ID)
131			viper.SetDefault("agents.task.model", model.ID)
132			viper.SetDefault("agents.title.model", model.ID)
133		}
134	}
135}
136
137func convertLocalModel(model localModel) Model {
138	return Model{
139		ID:                  ModelID("local." + model.ID),
140		Name:                friendlyModelName(model.ID),
141		Provider:            ProviderLocal,
142		APIModel:            model.ID,
143		ContextWindow:       cmp.Or(model.LoadedContextLength, 4096),
144		DefaultMaxTokens:    cmp.Or(model.LoadedContextLength, 4096),
145		CanReason:           true,
146		SupportsAttachments: true,
147	}
148}
149
// modelInfoRegex splits a model ID into up to three case-insensitive parts:
//
//  1. family:  a leading run of ASCII letters and digits ("llama", "deepseek")
//  2. version: an optional number, optionally prefixed by "r" or "v", that
//     follows a "-" or "_" separator ("3.2", "r1")
//  3. label:   an optional run of letters after the version ("distill")
//
// Anything after these parts is ignored.
//
// NOTE(review): because the family group also accepts digits, an ID such as
// "qwen2.5-7b" yields family "qwen2" and no version, so the ".5" is dropped.
var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`)

// friendlyModelName turns a local model ID into a human-readable display name.
//
// An "@tag" suffix (from the first "@") is split off and appended verbatim.
// Any "publisher/" prefix is then dropped from the remaining ID. The rest is
// parsed with modelInfoRegex into a capitalized family, an upper-cased
// version, and a capitalized label, and the non-empty parts are joined with
// spaces:
//
//	"lmstudio-community/llama-3.2-3b-instruct@q8_0" -> "Llama 3.2 q8_0"
//	"deepseek-r1-distill-qwen-7b"                   -> "Deepseek R1 Distill"
//
// If the stripped ID does not start with an ASCII letter or digit, modelID is
// returned unchanged.
func friendlyModelName(modelID string) string {
	// Cut the tag first and strip the publisher from what remains. The
	// previous order re-sliced the full modelID at "@", which reintroduced
	// the publisher prefix whenever both were present.
	mainID, tag, _ := strings.Cut(modelID, "@")
	if slash := strings.LastIndex(mainID, "/"); slash != -1 {
		mainID = mainID[slash+1:]
	}

	match := modelInfoRegex.FindStringSubmatch(mainID)
	if match == nil {
		return modelID
	}

	capitalize := func(s string) string {
		if s == "" {
			return ""
		}
		runes := []rune(s)
		runes[0] = unicode.ToUpper(runes[0])
		return string(runes)
	}

	// On success, FindStringSubmatch returns one entry per capture group
	// (so len(match) == 4), with "" for groups that did not participate.
	family, version, label := match[1], match[2], match[3]

	// family is never empty: its group requires at least one character.
	parts := []string{capitalize(family)}
	if version != "" {
		parts = append(parts, strings.ToUpper(version))
	}
	if label != "" {
		parts = append(parts, capitalize(label))
	}
	if tag != "" {
		parts = append(parts, tag)
	}

	return strings.Join(parts, " ")
}