local.go

  1package models
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"net/http"
  8	"net/url"
  9	"os"
 10	"regexp"
 11	"strings"
 12	"unicode"
 13
 14	"github.com/charmbracelet/crush/internal/logging"
 15	"github.com/spf13/viper"
 16)
 17
 18const (
 19	ProviderLocal ModelProvider = "local"
 20
 21	localModelsPath        = "v1/models"
 22	lmStudioBetaModelsPath = "api/v0/models"
 23)
 24
 25func init() {
 26	if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" {
 27		localEndpoint, err := url.Parse(endpoint)
 28		if err != nil {
 29			logging.Debug("Failed to parse local endpoint",
 30				"error", err,
 31				"endpoint", endpoint,
 32			)
 33			return
 34		}
 35
 36		load := func(url *url.URL, path string) []localModel {
 37			url.Path = path
 38			return listLocalModels(url.String())
 39		}
 40
 41		models := load(localEndpoint, lmStudioBetaModelsPath)
 42
 43		if len(models) == 0 {
 44			models = load(localEndpoint, localModelsPath)
 45		}
 46
 47		if len(models) == 0 {
 48			logging.Debug("No local models found",
 49				"endpoint", endpoint,
 50			)
 51			return
 52		}
 53
 54		loadLocalModels(models)
 55
 56		viper.SetDefault("providers.local.apiKey", "dummy")
 57	}
 58}
 59
 60type localModelList struct {
 61	Data []localModel `json:"data"`
 62}
 63
 64type localModel struct {
 65	ID                  string `json:"id"`
 66	Object              string `json:"object"`
 67	Type                string `json:"type"`
 68	Publisher           string `json:"publisher"`
 69	Arch                string `json:"arch"`
 70	CompatibilityType   string `json:"compatibility_type"`
 71	Quantization        string `json:"quantization"`
 72	State               string `json:"state"`
 73	MaxContextLength    int64  `json:"max_context_length"`
 74	LoadedContextLength int64  `json:"loaded_context_length"`
 75}
 76
 77func listLocalModels(modelsEndpoint string) []localModel {
 78	res, err := http.NewRequestWithContext(context.Background(), http.MethodGet, modelsEndpoint, nil)
 79	if err != nil {
 80		logging.Debug("Failed to list local models",
 81			"error", err,
 82			"endpoint", modelsEndpoint,
 83		)
 84	}
 85	defer res.Body.Close()
 86
 87	if res.Response.StatusCode != http.StatusOK {
 88		logging.Debug("Failed to list local models",
 89			"status", res.Response.Status,
 90			"endpoint", modelsEndpoint,
 91		)
 92	}
 93
 94	var modelList localModelList
 95	if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil {
 96		logging.Debug("Failed to list local models",
 97			"error", err,
 98			"endpoint", modelsEndpoint,
 99		)
100	}
101
102	var supportedModels []localModel
103	for _, model := range modelList.Data {
104		if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) {
105			if model.Object != "model" || model.Type != "llm" {
106				logging.Debug("Skipping unsupported LMStudio model",
107					"endpoint", modelsEndpoint,
108					"id", model.ID,
109					"object", model.Object,
110					"type", model.Type,
111				)
112
113				continue
114			}
115		}
116
117		supportedModels = append(supportedModels, model)
118	}
119
120	return supportedModels
121}
122
123func loadLocalModels(models []localModel) {
124	for i, m := range models {
125		model := convertLocalModel(m)
126		SupportedModels[model.ID] = model
127
128		if i == 0 || m.State == "loaded" {
129			viper.SetDefault("agents.coder.model", model.ID)
130			viper.SetDefault("agents.summarizer.model", model.ID)
131			viper.SetDefault("agents.task.model", model.ID)
132			viper.SetDefault("agents.title.model", model.ID)
133		}
134	}
135}
136
137func convertLocalModel(model localModel) Model {
138	return Model{
139		ID:                  ModelID("local." + model.ID),
140		Name:                friendlyModelName(model.ID),
141		Provider:            ProviderLocal,
142		APIModel:            model.ID,
143		ContextWindow:       cmp.Or(model.LoadedContextLength, 4096),
144		DefaultMaxTokens:    cmp.Or(model.LoadedContextLength, 4096),
145		CanReason:           true,
146		SupportsAttachments: true,
147	}
148}
149
// modelInfoRegex splits a bare model name (no publisher prefix, no tag) into
// up to three case-insensitive parts:
//
//  1. family:  the leading run of letters and digits;
//  2. version: optional, an optional "r" or "v" followed by a digit and any
//     further digits or dots, optionally preceded by "-" or "_";
//  3. label:   optional, a run of letters, optionally preceded by "-" or "_".
//
// Anything after those parts is matched but ignored.
var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`)

// friendlyModelName derives a display name from a model ID of the form
// "[publisher/]name[@tag]".
//
// The tag (everything after the first "@") is split off first, then
// everything up to and including the last "/" of the remainder is dropped.
// The bare name is matched against modelInfoRegex and the result is the
// space-separated concatenation of the non-empty parts: family (first letter
// upper-cased), version (upper-cased), label (first letter upper-cased) and
// the tag verbatim. For example "meta/llama-3.1-instruct@q4" becomes
// "Llama 3.1 Instruct q4".
//
// If the bare name does not match the pattern, modelID is returned unchanged.
func friendlyModelName(modelID string) string {
	// Separate the tag before stripping the publisher so that both steps act
	// on the same string; slicing the raw ID again here would bring the
	// publisher prefix back whenever a tag is present.
	mainID, tag, _ := strings.Cut(modelID, "@")

	if slash := strings.LastIndex(mainID, "/"); slash != -1 {
		mainID = mainID[slash+1:]
	}

	match := modelInfoRegex.FindStringSubmatch(mainID)
	if match == nil {
		return modelID
	}

	// capitalize upper-cases the first rune of s and leaves the rest as is.
	capitalize := func(s string) string {
		if s == "" {
			return ""
		}
		runes := []rune(s)
		runes[0] = unicode.ToUpper(runes[0])
		return string(runes)
	}

	// The pattern has exactly three groups; a group that did not participate
	// in the match is reported as the empty string.
	candidates := [...]string{
		capitalize(match[1]),
		strings.ToUpper(match[2]),
		capitalize(match[3]),
		tag,
	}

	parts := make([]string, 0, len(candidates))
	for _, part := range candidates {
		if part != "" {
			parts = append(parts, part)
		}
	}

	return strings.Join(parts, " ")
}