From a2524a1d059199cad2c9c5c06ea576f9605d0147 Mon Sep 17 00:00:00 2001 From: Pietjan Oostra Date: Sun, 11 May 2025 19:43:49 +0200 Subject: [PATCH] Add local provider --- internal/config/config.go | 2 +- internal/llm/agent/agent.go | 2 +- internal/llm/models/local.go | 191 ++++++++++++++++++++++++++++++ internal/llm/provider/provider.go | 10 +- 4 files changed, 202 insertions(+), 3 deletions(-) create mode 100644 internal/llm/models/local.go diff --git a/internal/config/config.go b/internal/config/config.go index 351bc501ace11b8de763975fbedda4ab28e7cb1a..5a0905bba239c0d7c79f669801ef9b3a5caa9cf9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -526,7 +526,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error { } // Validate reasoning effort for models that support reasoning - if model.CanReason && provider == models.ProviderOpenAI { + if model.CanReason && provider == models.ProviderOpenAI || provider == models.ProviderLocal { if agent.ReasoningEffort == "" { // Set default reasoning effort for models that support it logging.Info("setting default reasoning effort for model that supports reasoning", diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 0ac7f65ff37f2cbab3fd45e5ea963542f431e75c..4f31fe75d688aa2c4fdd80a4f633fe35d45125cc 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -715,7 +715,7 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error) provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)), provider.WithMaxTokens(maxTokens), } - if model.Provider == models.ProviderOpenAI && model.CanReason { + if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason { opts = append( opts, provider.WithOpenAIOptions( diff --git a/internal/llm/models/local.go b/internal/llm/models/local.go new file mode 100644 index 0000000000000000000000000000000000000000..252f6a9f95a8f023b4f23275e37b08b1d66f8862 --- /dev/null +++ b/internal/llm/models/local.go @@ -0,0 +1,191 @@ +package models + +import ( + "cmp" + "encoding/json" + "log/slog" + "net/http" + "net/url" + "os" + "regexp" + "strings" + "unicode" + + "github.com/spf13/viper" +) + +const ( + ProviderLocal ModelProvider = "local" + + localModelsPath = "v1/models" + lmStudioBetaModelsPath = "api/v0/models" +) + +func init() { + if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" { + localEndpoint, err := url.Parse(endpoint) + if err != nil { + slog.Debug("Failed to parse local endpoint", + "error", err, + "endpoint", endpoint, + ) + return + } + + load := func(url *url.URL, path string) []localModel { + url.Path = path + return listLocalModels(url.String()) + } + + models := load(localEndpoint, lmStudioBetaModelsPath) + + if len(models) == 0 { + models = load(localEndpoint, localModelsPath) + } + + if len(models) == 0 { + slog.Debug("No local models found", + "endpoint", endpoint, + ) + return + } + + loadLocalModels(models) + + viper.SetDefault("providers.local.apiKey", "dummy") + ProviderPopularity[ProviderLocal] = 0 + } +} + +type localModelList struct { + Data []localModel `json:"data"` +} + +type localModel struct { + ID string `json:"id"` + Object string `json:"object"` + Type string `json:"type"` + Publisher string `json:"publisher"` + Arch string `json:"arch"` + CompatibilityType string `json:"compatibility_type"` + Quantization string `json:"quantization"` + State string `json:"state"` + MaxContextLength int64 `json:"max_context_length"` + LoadedContextLength int64 `json:"loaded_context_length"` +} + +func listLocalModels(modelsEndpoint string) []localModel { + res, err := http.Get(modelsEndpoint) + if err != nil { + slog.Debug("Failed to list local models", + "error", err, + "endpoint", modelsEndpoint, + ) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + slog.Debug("Failed to list local models", + "status", res.StatusCode, + "endpoint", modelsEndpoint, + ) + } + + var modelList localModelList + if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil { + slog.Debug("Failed to list local models", + "error", err, + "endpoint", modelsEndpoint, + ) + } + + var supportedModels []localModel + for _, model := range modelList.Data { + if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) { + if model.Object != "model" || model.Type != "llm" { + slog.Debug("Skipping unsupported LMStudio model", + "endpoint", modelsEndpoint, + "id", model.ID, + "object", model.Object, + "type", model.Type, + ) + + continue + } + } + + supportedModels = append(supportedModels, model) + } + + return supportedModels +} + +func loadLocalModels(models []localModel) { + for i, m := range models { + model := convertLocalModel(m) + SupportedModels[model.ID] = model + + if i == 1 || m.State == "loaded" { + viper.SetDefault("agents.coder.model", model.ID) + viper.SetDefault("agents.summarizer.model", model.ID) + viper.SetDefault("agents.task.model", model.ID) + viper.SetDefault("agents.title.model", model.ID) + } + } +} + +func convertLocalModel(model localModel) Model { + return Model{ + ID: ModelID("local." + model.ID), + Name: friendlyModelName(model.ID), + Provider: ProviderLocal, + APIModel: model.ID, + ContextWindow: cmp.Or(model.LoadedContextLength, 4096), + DefaultMaxTokens: cmp.Or(model.LoadedContextLength, 4096), + CanReason: true, + SupportsAttachments: true, + } +} + +var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`) + +func friendlyModelName(modelID string) string { + match := modelInfoRegex.FindStringSubmatch(modelID) + if match == nil { + return modelID + } + + capitalize := func(s string) string { + if s == "" { + return "" + } + runes := []rune(s) + runes[0] = unicode.ToUpper(runes[0]) + return string(runes) + } + + family := capitalize(match[1]) + version := "" + label := "" + + if len(match) > 2 && match[2] != "" { + version = strings.ToUpper(match[2]) + } + + if len(match) > 3 && match[3] != "" { + label = capitalize(match[3]) + } + + var parts []string + if family != "" { + parts = append(parts, family) + } + if version != "" { + parts = append(parts, version) + } + if label != "" { + parts = append(parts, label) + } + + return strings.Join(parts, " ") +} diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 4f5164c918df9247be0e75097d19afd3636270d5..08175450a6d85953e996c08f436982a1981053b6 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -3,6 +3,7 @@ package provider import ( "context" "fmt" + "os" "github.com/opencode-ai/opencode/internal/llm/models" "github.com/opencode-ai/opencode/internal/llm/tools" @@ -145,7 +146,14 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption options: clientOptions, client: newOpenAIClient(clientOptions), }, nil - + case models.ProviderLocal: + clientOptions.openaiOptions = append(clientOptions.openaiOptions, + WithOpenAIBaseURL(os.Getenv("LOCAL_ENDPOINT")), + ) + return &baseProvider[OpenAIClient]{ + options: clientOptions, + client: newOpenAIClient(clientOptions), + }, nil case models.ProviderMock: // TODO: implement mock client for test panic("not implemented")