Add local provider

Pietjan Oostra created

Change summary

internal/config/config.go         |   2 
internal/llm/agent/agent.go       |   2 
internal/llm/models/local.go      | 191 +++++++++++++++++++++++++++++++++
internal/llm/provider/provider.go |  10 +
4 files changed, 202 insertions(+), 3 deletions(-)

Detailed changes

internal/config/config.go 🔗

@@ -526,7 +526,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
 	}
 
 	// Validate reasoning effort for models that support reasoning
-	if model.CanReason && provider == models.ProviderOpenAI {
+	if model.CanReason && provider == models.ProviderOpenAI || provider == models.ProviderLocal {
 		if agent.ReasoningEffort == "" {
 			// Set default reasoning effort for models that support it
 			logging.Info("setting default reasoning effort for model that supports reasoning",

internal/llm/agent/agent.go 🔗

@@ -715,7 +715,7 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error)
 		provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
 		provider.WithMaxTokens(maxTokens),
 	}
-	if model.Provider == models.ProviderOpenAI && model.CanReason {
+	if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason {
 		opts = append(
 			opts,
 			provider.WithOpenAIOptions(

internal/llm/models/local.go 🔗

@@ -0,0 +1,191 @@
+package models
+
+import (
+	"cmp"
+	"encoding/json"
+	"log/slog"
+	"net/http"
+	"net/url"
+	"os"
+	"regexp"
+	"strings"
+	"unicode"
+
+	"github.com/spf13/viper"
+)
+
+const (
+	ProviderLocal ModelProvider = "local"
+
+	localModelsPath        = "v1/models"
+	lmStudioBetaModelsPath = "api/v0/models"
+)
+
+func init() {
+	if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" {
+		localEndpoint, err := url.Parse(endpoint)
+		if err != nil {
+			slog.Debug("Failed to parse local endpoint",
+				"error", err,
+				"endpoint", endpoint,
+			)
+			return
+		}
+
+		load := func(url *url.URL, path string) []localModel {
+			url.Path = path
+			return listLocalModels(url.String())
+		}
+
+		models := load(localEndpoint, lmStudioBetaModelsPath)
+
+		if len(models) == 0 {
+			models = load(localEndpoint, localModelsPath)
+		}
+
+		if len(models) == 0 {
+			slog.Debug("No local models found",
+				"endpoint", endpoint,
+			)
+			return
+		}
+
+		loadLocalModels(models)
+
+		viper.SetDefault("providers.local.apiKey", "dummy")
+		ProviderPopularity[ProviderLocal] = 0
+	}
+}
+
+type localModelList struct {
+	Data []localModel `json:"data"`
+}
+
+type localModel struct {
+	ID                  string `json:"id"`
+	Object              string `json:"object"`
+	Type                string `json:"type"`
+	Publisher           string `json:"publisher"`
+	Arch                string `json:"arch"`
+	CompatibilityType   string `json:"compatibility_type"`
+	Quantization        string `json:"quantization"`
+	State               string `json:"state"`
+	MaxContextLength    int64  `json:"max_context_length"`
+	LoadedContextLength int64  `json:"loaded_context_length"`
+}
+
+func listLocalModels(modelsEndpoint string) []localModel {
+	res, err := http.Get(modelsEndpoint)
+	if err != nil {
+		slog.Debug("Failed to list local models",
+			"error", err,
+			"endpoint", modelsEndpoint,
+		)
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode != http.StatusOK {
+		slog.Debug("Failed to list local models",
+			"status", res.StatusCode,
+			"endpoint", modelsEndpoint,
+		)
+	}
+
+	var modelList localModelList
+	if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil {
+		slog.Debug("Failed to list local models",
+			"error", err,
+			"endpoint", modelsEndpoint,
+		)
+	}
+
+	var supportedModels []localModel
+	for _, model := range modelList.Data {
+		if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) {
+			if model.Object != "model" || model.Type != "llm" {
+				slog.Debug("Skipping unsupported LMStudio model",
+					"endpoint", modelsEndpoint,
+					"id", model.ID,
+					"object", model.Object,
+					"type", model.Type,
+				)
+
+				continue
+			}
+		}
+
+		supportedModels = append(supportedModels, model)
+	}
+
+	return supportedModels
+}
+
+func loadLocalModels(models []localModel) {
+	for i, m := range models {
+		model := convertLocalModel(m)
+		SupportedModels[model.ID] = model
+
+		if i == 1 || m.State == "loaded" {
+			viper.SetDefault("agents.coder.model", model.ID)
+			viper.SetDefault("agents.summarizer.model", model.ID)
+			viper.SetDefault("agents.task.model", model.ID)
+			viper.SetDefault("agents.title.model", model.ID)
+		}
+	}
+}
+
+func convertLocalModel(model localModel) Model {
+	return Model{
+		ID:                  ModelID("local." + model.ID),
+		Name:                friendlyModelName(model.ID),
+		Provider:            ProviderLocal,
+		APIModel:            model.ID,
+		ContextWindow:       cmp.Or(model.LoadedContextLength, 4096),
+		DefaultMaxTokens:    cmp.Or(model.LoadedContextLength, 4096),
+		CanReason:           true,
+		SupportsAttachments: true,
+	}
+}
+
+var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`)
+
+func friendlyModelName(modelID string) string {
+	match := modelInfoRegex.FindStringSubmatch(modelID)
+	if match == nil {
+		return modelID
+	}
+
+	capitalize := func(s string) string {
+		if s == "" {
+			return ""
+		}
+		runes := []rune(s)
+		runes[0] = unicode.ToUpper(runes[0])
+		return string(runes)
+	}
+
+	family := capitalize(match[1])
+	version := ""
+	label := ""
+
+	if len(match) > 2 && match[2] != "" {
+		version = strings.ToUpper(match[2])
+	}
+
+	if len(match) > 3 && match[3] != "" {
+		label = capitalize(match[3])
+	}
+
+	var parts []string
+	if family != "" {
+		parts = append(parts, family)
+	}
+	if version != "" {
+		parts = append(parts, version)
+	}
+	if label != "" {
+		parts = append(parts, label)
+	}
+
+	return strings.Join(parts, " ")
+}

internal/llm/provider/provider.go 🔗

@@ -3,6 +3,7 @@ package provider
 import (
 	"context"
 	"fmt"
+	"os"
 
 	"github.com/opencode-ai/opencode/internal/llm/models"
 	"github.com/opencode-ai/opencode/internal/llm/tools"
@@ -145,7 +146,14 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption
 			options: clientOptions,
 			client:  newOpenAIClient(clientOptions),
 		}, nil
-
+	case models.ProviderLocal:
+		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
+			WithOpenAIBaseURL(os.Getenv("LOCAL_ENDPOINT")),
+		)
+		return &baseProvider[OpenAIClient]{
+			options: clientOptions,
+			client:  newOpenAIClient(clientOptions),
+		}, nil
 	case models.ProviderMock:
 		// TODO: implement mock client for test
 		panic("not implemented")