@@ -715,7 +715,7 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error)
provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
provider.WithMaxTokens(maxTokens),
}
- if model.Provider == models.ProviderOpenAI && model.CanReason {
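+	// Reasoning options apply to OpenAI models and to OpenAI-compatible local models that can reason.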
+	if (model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal) && model.CanReason {
opts = append(
opts,
provider.WithOpenAIOptions(
@@ -0,0 +1,191 @@
+package models
+
+import (
+ "cmp"
+ "encoding/json"
+ "log/slog"
+ "net/http"
+ "net/url"
+ "os"
+ "regexp"
+ "strings"
+ "unicode"
+
+ "github.com/spf13/viper"
+)
+
+const (
+ ProviderLocal ModelProvider = "local"
+
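+	// Model-listing paths: the standard OpenAI-compatible listing and LM
+	// Studio's beta REST API, which reports extra metadata such as load
+	// state and context lengths.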
+ localModelsPath = "v1/models"
+ lmStudioBetaModelsPath = "api/v0/models"
+)
+
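+// init discovers models served by an OpenAI-compatible local endpoint when
+// LOCAL_ENDPOINT is set, trying LM Studio's beta API first and falling back
+// to the standard v1/models listing.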
+func init() {
+ if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" {
+ localEndpoint, err := url.Parse(endpoint)
+ if err != nil {
+ slog.Debug("Failed to parse local endpoint",
+ "error", err,
+ "endpoint", endpoint,
+ )
+ return
+ }
+
+		load := func(u *url.URL, path string) []localModel {
+			u.Path = path
+			return listLocalModels(u.String())
+		}
+
+ models := load(localEndpoint, lmStudioBetaModelsPath)
+
+ if len(models) == 0 {
+ models = load(localEndpoint, localModelsPath)
+ }
+
+ if len(models) == 0 {
+ slog.Debug("No local models found",
+ "endpoint", endpoint,
+ )
+ return
+ }
+
+ loadLocalModels(models)
+
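+		// The OpenAI-compatible client expects an API key, so register a
+		// placeholder and give the local provider a popularity entry.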
+ viper.SetDefault("providers.local.apiKey", "dummy")
+ ProviderPopularity[ProviderLocal] = 0
+ }
+}
+
+type localModelList struct {
+ Data []localModel `json:"data"`
+}
+
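+// localModel describes a single entry returned by the model-listing
+// endpoints; fields such as State and the context lengths are reported by
+// LM Studio's beta API.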
+type localModel struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Type string `json:"type"`
+ Publisher string `json:"publisher"`
+ Arch string `json:"arch"`
+ CompatibilityType string `json:"compatibility_type"`
+ Quantization string `json:"quantization"`
+ State string `json:"state"`
+ MaxContextLength int64 `json:"max_context_length"`
+ LoadedContextLength int64 `json:"loaded_context_length"`
+}
+
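+// listLocalModels fetches the model list from the given endpoint and, when
+// talking to LM Studio's beta API, filters out entries that are not plain
+// LLMs.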
+func listLocalModels(modelsEndpoint string) []localModel {
+ res, err := http.Get(modelsEndpoint)
+ if err != nil {
+ slog.Debug("Failed to list local models",
+ "error", err,
+ "endpoint", modelsEndpoint,
+ )
+		return nil
+	}
+ defer res.Body.Close()
+
+ if res.StatusCode != http.StatusOK {
+ slog.Debug("Failed to list local models",
+ "status", res.StatusCode,
+ "endpoint", modelsEndpoint,
+ )
+		return nil
+	}
+
+ var modelList localModelList
+ if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil {
+ slog.Debug("Failed to list local models",
+ "error", err,
+ "endpoint", modelsEndpoint,
+ )
+		return nil
+	}
+
+ var supportedModels []localModel
+ for _, model := range modelList.Data {
+ if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) {
+ if model.Object != "model" || model.Type != "llm" {
+ slog.Debug("Skipping unsupported LMStudio model",
+ "endpoint", modelsEndpoint,
+ "id", model.ID,
+ "object", model.Object,
+ "type", model.Type,
+ )
+
+ continue
+ }
+ }
+
+ supportedModels = append(supportedModels, model)
+ }
+
+ return supportedModels
+}
+
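+// loadLocalModels registers the discovered models and points the agent model
+// defaults at the first model found, or at one that is already loaded.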
+func loadLocalModels(models []localModel) {
+ for i, m := range models {
+ model := convertLocalModel(m)
+ SupportedModels[model.ID] = model
+
+		if i == 0 || m.State == "loaded" {
+ viper.SetDefault("agents.coder.model", model.ID)
+ viper.SetDefault("agents.summarizer.model", model.ID)
+ viper.SetDefault("agents.task.model", model.ID)
+ viper.SetDefault("agents.title.model", model.ID)
+ }
+ }
+}
+
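+// convertLocalModel maps a locally served model onto the internal Model type,
+// namespacing its ID under the local provider and defaulting the context
+// window when the endpoint does not report one.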
+func convertLocalModel(model localModel) Model {
+ return Model{
+ ID: ModelID("local." + model.ID),
+ Name: friendlyModelName(model.ID),
+ Provider: ProviderLocal,
+ APIModel: model.ID,
+		ContextWindow:       cmp.Or(model.LoadedContextLength, model.MaxContextLength, 4096),
+		DefaultMaxTokens:    cmp.Or(model.LoadedContextLength, model.MaxContextLength, 4096),
+ CanReason: true,
+ SupportsAttachments: true,
+ }
+}
+
+var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`)
+
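+// friendlyModelName derives a display name from a raw model ID by splitting
+// it into a family, an optional version, and an optional label, e.g.
+// "llama-3.2-instruct" becomes "Llama 3.2 Instruct".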
+func friendlyModelName(modelID string) string {
+ match := modelInfoRegex.FindStringSubmatch(modelID)
+ if match == nil {
+ return modelID
+ }
+
+ capitalize := func(s string) string {
+ if s == "" {
+ return ""
+ }
+ runes := []rune(s)
+ runes[0] = unicode.ToUpper(runes[0])
+ return string(runes)
+ }
+
+ family := capitalize(match[1])
+ version := ""
+ label := ""
+
+ if len(match) > 2 && match[2] != "" {
+ version = strings.ToUpper(match[2])
+ }
+
+ if len(match) > 3 && match[3] != "" {
+ label = capitalize(match[3])
+ }
+
+ var parts []string
+ if family != "" {
+ parts = append(parts, family)
+ }
+ if version != "" {
+ parts = append(parts, version)
+ }
+ if label != "" {
+ parts = append(parts, label)
+ }
+
+ return strings.Join(parts, " ")
+}
@@ -3,6 +3,7 @@ package provider
import (
"context"
"fmt"
+	"os"

 	"github.com/opencode-ai/opencode/internal/llm/models"
"github.com/opencode-ai/opencode/internal/llm/tools"
@@ -145,7 +146,14 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
-
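+	// Local endpoints speak the OpenAI-compatible API, so reuse the OpenAI
+	// client pointed at LOCAL_ENDPOINT.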
+ case models.ProviderLocal:
+ clientOptions.openaiOptions = append(clientOptions.openaiOptions,
+ WithOpenAIBaseURL(os.Getenv("LOCAL_ENDPOINT")),
+ )
+ return &baseProvider[OpenAIClient]{
+ options: clientOptions,
+ client: newOpenAIClient(clientOptions),
+ }, nil
case models.ProviderMock:
// TODO: implement mock client for test
panic("not implemented")