diff --git a/AGENTS.md b/AGENTS.md
index 9203f930028e00e54849f1963328950ee2901452..d0a64123d75c6494362d23c32cfd5b35bf18651f 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -27,3 +27,5 @@
./bin/shelley -config /exe.dev/shelley.json -db /tmp/shelley-test.db serve -port 8002
```
Then use browser tools to navigate to http://localhost:8002/ and interact with the UI.
+13. NEVER use alert(), confirm(), or prompt(). Use proper UI components like tooltips, modals, or toasts instead (see the example below).
+14. SQL migrations and frontend changes require rebuilding the binary (`make build` or `go generate ./... && cd ui && pnpm run build`).
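+
+Example for rule 13 (a minimal sketch; `showToast` is hypothetical, substitute whatever toast/modal component the UI actually provides):
+```ts
+// Instead of: alert('Model saved');
+showToast({ message: 'Model saved', variant: 'success' });
+```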
diff --git a/cmd/shelley/main.go b/cmd/shelley/main.go
index f0e450a4a344c352d442b30cbdfb7a0cf2a04baf..34c17cb6ac9e31cb3335a6747de1b372f9128c83 100644
--- a/cmd/shelley/main.go
+++ b/cmd/shelley/main.go
@@ -101,7 +101,7 @@ func runServe(global GlobalConfig, args []string) {
// Build LLM configuration
llmConfig := buildLLMConfig(logger, global.ConfigPath, global.TerminalURL, global.DefaultModel, database)
- // Initialize LLM service manager
+ // Initialize LLM service manager (includes custom model support via database)
llmManager := server.NewLLMServiceManager(llmConfig)
// Log available models
diff --git a/db/db.go b/db/db.go
index d95a2baf0f718203d33f842dea9c7af0def7a032..9f7e3fdd71d1cfe13a88407d09abdaeaba175476 100644
--- a/db/db.go
+++ b/db/db.go
@@ -115,6 +115,20 @@ func (db *DB) Migrate(ctx context.Context) error {
// Sort migrations by number
sort.Strings(migrations)
+ // Check for duplicate migration numbers
+ seenNumbers := make(map[string]string) // number -> filename
+ for _, migration := range migrations {
+ matches := migrationPattern.FindStringSubmatch(migration)
+ if len(matches) < 2 {
+ continue
+ }
+ num := matches[1]
+ if existing, ok := seenNumbers[num]; ok {
+ return fmt.Errorf("duplicate migration number %s: %s and %s", num, existing, migration)
+ }
+ seenNumbers[num] = migration
+ }
+
// Get executed migrations
executedMigrations := make(map[int]bool)
var tableName string
@@ -850,3 +864,68 @@ func reconstructRequestBody(ctx context.Context, q *generated.Queries, requestID
*result = parentBody[:prefixLen] + suffix
return nil
}
+
+// GetModels returns all models from the database
+func (db *DB) GetModels(ctx context.Context) ([]generated.Model, error) {
+ var models []generated.Model
+ err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
+ q := generated.New(rx.Conn())
+ var err error
+ models, err = q.GetModels(ctx)
+ return err
+ })
+ return models, err
+}
+
+// GetModel returns a model by ID
+func (db *DB) GetModel(ctx context.Context, modelID string) (*generated.Model, error) {
+ var model generated.Model
+ err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
+ q := generated.New(rx.Conn())
+ var err error
+ model, err = q.GetModel(ctx, modelID)
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return &model, nil
+}
+
+// CreateModel creates a new model
+func (db *DB) CreateModel(ctx context.Context, params generated.CreateModelParams) (*generated.Model, error) {
+ var model generated.Model
+ err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
+ q := generated.New(tx.Conn())
+ var err error
+ model, err = q.CreateModel(ctx, params)
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return &model, nil
+}
+
+// UpdateModel updates a model
+func (db *DB) UpdateModel(ctx context.Context, params generated.UpdateModelParams) (*generated.Model, error) {
+ var model generated.Model
+ err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
+ q := generated.New(tx.Conn())
+ var err error
+ model, err = q.UpdateModel(ctx, params)
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return &model, nil
+}
+
+// DeleteModel deletes a model
+func (db *DB) DeleteModel(ctx context.Context, modelID string) error {
+ return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
+ q := generated.New(tx.Conn())
+ return q.DeleteModel(ctx, modelID)
+ })
+}
diff --git a/db/generated/llm_requests.sql.go b/db/generated/llm_requests.sql.go
index 599d6023040b91d281b84f32b1d87048be6e1ac5..b6ccb368f571721fc9930fec8f3efba9da333e1a 100644
--- a/db/generated/llm_requests.sql.go
+++ b/db/generated/llm_requests.sql.go
@@ -150,22 +150,24 @@ func (q *Queries) InsertLLMRequest(ctx context.Context, arg InsertLLMRequestPara
}
const listRecentLLMRequests = `-- name: ListRecentLLMRequests :many
-SELECT
- id,
- conversation_id,
- model,
- provider,
- url,
- LENGTH(request_body) as request_body_length,
- LENGTH(response_body) as response_body_length,
- status_code,
- error,
- duration_ms,
- created_at,
- prefix_request_id,
- prefix_length
-FROM llm_requests
-ORDER BY id DESC
+SELECT
+ r.id,
+ r.conversation_id,
+ r.model,
+ m.display_name as model_display_name,
+ r.provider,
+ r.url,
+ LENGTH(r.request_body) as request_body_length,
+ LENGTH(r.response_body) as response_body_length,
+ r.status_code,
+ r.error,
+ r.duration_ms,
+ r.created_at,
+ r.prefix_request_id,
+ r.prefix_length
+FROM llm_requests r
+LEFT JOIN models m ON r.model = m.model_id
+ORDER BY r.id DESC
LIMIT ?
`
@@ -173,6 +175,7 @@ type ListRecentLLMRequestsRow struct {
ID int64 `json:"id"`
ConversationID *string `json:"conversation_id"`
Model string `json:"model"`
+ ModelDisplayName *string `json:"model_display_name"`
Provider string `json:"provider"`
Url string `json:"url"`
RequestBodyLength *int64 `json:"request_body_length"`
@@ -198,6 +201,7 @@ func (q *Queries) ListRecentLLMRequests(ctx context.Context, limit int64) ([]Lis
&i.ID,
&i.ConversationID,
&i.Model,
+ &i.ModelDisplayName,
&i.Provider,
&i.Url,
&i.RequestBodyLength,
diff --git a/db/generated/models.go b/db/generated/models.go
index 4bd0ad9a39e9e498850b1886f7aeb214291b8c49..4fcfda92a6481fb51e28ddd27f9cf5b6c669242f 100644
--- a/db/generated/models.go
+++ b/db/generated/models.go
@@ -52,3 +52,16 @@ type Migration struct {
MigrationName string `json:"migration_name"`
ExecutedAt *time.Time `json:"executed_at"`
}
+
+type Model struct {
+ ModelID string `json:"model_id"`
+ DisplayName string `json:"display_name"`
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ ApiKey string `json:"api_key"`
+ ModelName string `json:"model_name"`
+ MaxTokens int64 `json:"max_tokens"`
+ Tags string `json:"tags"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
diff --git a/db/generated/models.sql.go b/db/generated/models.sql.go
new file mode 100644
index 0000000000000000000000000000000000000000..2abe2139f8e27c9c386a7c6351908063cb96b8f1
--- /dev/null
+++ b/db/generated/models.sql.go
@@ -0,0 +1,175 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+// sqlc v1.30.0
+// source: models.sql
+
+package generated
+
+import (
+ "context"
+)
+
+const createModel = `-- name: CreateModel :one
+INSERT INTO models (model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags)
+VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+RETURNING model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags, created_at, updated_at
+`
+
+type CreateModelParams struct {
+ ModelID string `json:"model_id"`
+ DisplayName string `json:"display_name"`
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ ApiKey string `json:"api_key"`
+ ModelName string `json:"model_name"`
+ MaxTokens int64 `json:"max_tokens"`
+ Tags string `json:"tags"`
+}
+
+func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model, error) {
+ row := q.db.QueryRowContext(ctx, createModel,
+ arg.ModelID,
+ arg.DisplayName,
+ arg.ProviderType,
+ arg.Endpoint,
+ arg.ApiKey,
+ arg.ModelName,
+ arg.MaxTokens,
+ arg.Tags,
+ )
+ var i Model
+ err := row.Scan(
+ &i.ModelID,
+ &i.DisplayName,
+ &i.ProviderType,
+ &i.Endpoint,
+ &i.ApiKey,
+ &i.ModelName,
+ &i.MaxTokens,
+ &i.Tags,
+ &i.CreatedAt,
+ &i.UpdatedAt,
+ )
+ return i, err
+}
+
+const deleteModel = `-- name: DeleteModel :exec
+DELETE FROM models WHERE model_id = ?
+`
+
+func (q *Queries) DeleteModel(ctx context.Context, modelID string) error {
+ _, err := q.db.ExecContext(ctx, deleteModel, modelID)
+ return err
+}
+
+const getModel = `-- name: GetModel :one
+SELECT model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags, created_at, updated_at FROM models WHERE model_id = ?
+`
+
+func (q *Queries) GetModel(ctx context.Context, modelID string) (Model, error) {
+ row := q.db.QueryRowContext(ctx, getModel, modelID)
+ var i Model
+ err := row.Scan(
+ &i.ModelID,
+ &i.DisplayName,
+ &i.ProviderType,
+ &i.Endpoint,
+ &i.ApiKey,
+ &i.ModelName,
+ &i.MaxTokens,
+ &i.Tags,
+ &i.CreatedAt,
+ &i.UpdatedAt,
+ )
+ return i, err
+}
+
+const getModels = `-- name: GetModels :many
+SELECT model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags, created_at, updated_at FROM models ORDER BY created_at ASC
+`
+
+func (q *Queries) GetModels(ctx context.Context) ([]Model, error) {
+ rows, err := q.db.QueryContext(ctx, getModels)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ items := []Model{}
+ for rows.Next() {
+ var i Model
+ if err := rows.Scan(
+ &i.ModelID,
+ &i.DisplayName,
+ &i.ProviderType,
+ &i.Endpoint,
+ &i.ApiKey,
+ &i.ModelName,
+ &i.MaxTokens,
+ &i.Tags,
+ &i.CreatedAt,
+ &i.UpdatedAt,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const updateModel = `-- name: UpdateModel :one
+UPDATE models
+SET display_name = ?,
+ provider_type = ?,
+ endpoint = ?,
+ api_key = ?,
+ model_name = ?,
+ max_tokens = ?,
+ tags = ?,
+ updated_at = CURRENT_TIMESTAMP
+WHERE model_id = ?
+RETURNING model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags, created_at, updated_at
+`
+
+type UpdateModelParams struct {
+ DisplayName string `json:"display_name"`
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ ApiKey string `json:"api_key"`
+ ModelName string `json:"model_name"`
+ MaxTokens int64 `json:"max_tokens"`
+ Tags string `json:"tags"`
+ ModelID string `json:"model_id"`
+}
+
+func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model, error) {
+ row := q.db.QueryRowContext(ctx, updateModel,
+ arg.DisplayName,
+ arg.ProviderType,
+ arg.Endpoint,
+ arg.ApiKey,
+ arg.ModelName,
+ arg.MaxTokens,
+ arg.Tags,
+ arg.ModelID,
+ )
+ var i Model
+ err := row.Scan(
+ &i.ModelID,
+ &i.DisplayName,
+ &i.ProviderType,
+ &i.Endpoint,
+ &i.ApiKey,
+ &i.ModelName,
+ &i.MaxTokens,
+ &i.Tags,
+ &i.CreatedAt,
+ &i.UpdatedAt,
+ )
+ return i, err
+}
diff --git a/db/query/llm_requests.sql b/db/query/llm_requests.sql
index 3f1b8c10db355687a4f1859e5522bf9fbc59655f..fa86bd8a447d0f4ba45fdd0600d41f6ad0a40174 100644
--- a/db/query/llm_requests.sql
+++ b/db/query/llm_requests.sql
@@ -24,22 +24,24 @@ LIMIT 1;
SELECT * FROM llm_requests WHERE id = ?;
-- name: ListRecentLLMRequests :many
-SELECT
- id,
- conversation_id,
- model,
- provider,
- url,
- LENGTH(request_body) as request_body_length,
- LENGTH(response_body) as response_body_length,
- status_code,
- error,
- duration_ms,
- created_at,
- prefix_request_id,
- prefix_length
-FROM llm_requests
-ORDER BY id DESC
+SELECT
+ r.id,
+ r.conversation_id,
+ r.model,
+ m.display_name as model_display_name,
+ r.provider,
+ r.url,
+ LENGTH(r.request_body) as request_body_length,
+ LENGTH(r.response_body) as response_body_length,
+ r.status_code,
+ r.error,
+ r.duration_ms,
+ r.created_at,
+ r.prefix_request_id,
+ r.prefix_length
+FROM llm_requests r
+LEFT JOIN models m ON r.model = m.model_id
+ORDER BY r.id DESC
LIMIT ?;
-- name: GetLLMRequestBody :one
diff --git a/db/query/models.sql b/db/query/models.sql
new file mode 100644
index 0000000000000000000000000000000000000000..1815da9b7ca6fc4eee337caefbb2c8a020eaf197
--- /dev/null
+++ b/db/query/models.sql
@@ -0,0 +1,26 @@
+-- name: GetModels :many
+SELECT * FROM models ORDER BY created_at ASC;
+
+-- name: GetModel :one
+SELECT * FROM models WHERE model_id = ?;
+
+-- name: CreateModel :one
+INSERT INTO models (model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags)
+VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+RETURNING *;
+
+-- name: UpdateModel :one
+UPDATE models
+SET display_name = ?,
+ provider_type = ?,
+ endpoint = ?,
+ api_key = ?,
+ model_name = ?,
+ max_tokens = ?,
+ tags = ?,
+ updated_at = CURRENT_TIMESTAMP
+WHERE model_id = ?
+RETURNING *;
+
+-- name: DeleteModel :exec
+DELETE FROM models WHERE model_id = ?;
diff --git a/db/schema/012-custom-models.sql b/db/schema/012-custom-models.sql
new file mode 100644
index 0000000000000000000000000000000000000000..358a0e458ba281440e5683abe7508ba6b8566582
--- /dev/null
+++ b/db/schema/012-custom-models.sql
@@ -0,0 +1,15 @@
+-- Models table
+-- Stores user-configured LLM models with API keys
+
+CREATE TABLE models (
+ model_id TEXT PRIMARY KEY,
+ display_name TEXT NOT NULL,
+ provider_type TEXT NOT NULL CHECK (provider_type IN ('anthropic', 'openai', 'openai-responses', 'gemini')),
+ endpoint TEXT NOT NULL,
+ api_key TEXT NOT NULL,
+ model_name TEXT NOT NULL, -- The actual model name sent to the API (e.g., "claude-sonnet-4-5-20250929")
+ max_tokens INTEGER NOT NULL DEFAULT 200000,
+ tags TEXT NOT NULL DEFAULT '', -- Comma-separated tags (e.g., "slug" for slug generation)
+ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
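+
+-- Example row (illustrative values only; max_tokens and tags fall back to their defaults):
+-- INSERT INTO models (model_id, display_name, provider_type, endpoint, api_key, model_name)
+-- VALUES ('custom-1a2b3c4d', 'My Claude', 'anthropic', 'https://api.anthropic.com', 'sk-ant-...', 'claude-sonnet-4-5');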
diff --git a/llm/oai/oai.go b/llm/oai/oai.go
index fa0504d2f16b7e6a18ac2426cf1769ebcb75c442..4349630d766a25e6b5c9aabac6b7c4b39162917b 100644
--- a/llm/oai/oai.go
+++ b/llm/oai/oai.go
@@ -292,6 +292,13 @@ var (
APIKeyEnv: OpenAIAPIKeyEnv,
}
+ GPT52Codex = Model{
+ UserName: "gpt-5.2-codex",
+ ModelName: "gpt-5.2-codex",
+ URL: OpenAIURL,
+ APIKeyEnv: OpenAIAPIKeyEnv,
+ }
+
// Skaband-specific model names.
// Provider details (URL and APIKeyEnv) are handled by skaband
Qwen = Model{
@@ -329,6 +336,7 @@ var ModelsRegistry = []Model{
GPT5Mini,
GPT5Nano,
GPT5Codex,
+ GPT52Codex,
O3,
O4Mini,
Gemini25Flash,
@@ -525,22 +533,6 @@ func fromLLMMessage(msg llm.Message) []openai.ChatCompletionMessage {
return messages
}
-// requiresMaxCompletionTokens returns true if the model requires max_completion_tokens instead of max_tokens.
-func (m Model) requiresMaxCompletionTokens() bool {
- // Reasoning models always use max_completion_tokens
- if m.IsReasoningModel {
- return true
- }
-
- // GPT-5 series models also require max_completion_tokens
- switch m.ModelName {
- case "gpt-5.1", "gpt-5.1-mini", "gpt-5.1-nano":
- return true
- default:
- return false
- }
-}
-
// fromLLMToolChoice converts llm.ToolChoice to the format expected by OpenAI.
func fromLLMToolChoice(tc *llm.ToolChoice) any {
if tc == nil {
@@ -812,15 +804,11 @@ func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error
// Create the OpenAI request
req := openai.ChatCompletionRequest{
- Model: model.ModelName,
- Messages: allMessages,
- Tools: tools,
- ToolChoice: fromLLMToolChoice(ir.ToolChoice), // TODO: make fromLLMToolChoice return an error when a perfect translation is not possible
- }
- if model.requiresMaxCompletionTokens() {
- req.MaxCompletionTokens = cmp.Or(s.MaxTokens, DefaultMaxTokens)
- } else {
- req.MaxTokens = cmp.Or(s.MaxTokens, DefaultMaxTokens)
+ Model: model.ModelName,
+ Messages: allMessages,
+ Tools: tools,
+ ToolChoice: fromLLMToolChoice(ir.ToolChoice), // TODO: make fromLLMToolChoice return an error when a perfect translation is not possible
+ MaxCompletionTokens: cmp.Or(s.MaxTokens, DefaultMaxTokens),
}
// Construct the full URL for logging and debugging
fullURL := baseURL + "/chat/completions"
diff --git a/llm/oai/oai_responses.go b/llm/oai/oai_responses.go
index 3fedd053f0fd2e3ae02b1a3002bcae2b8b89c6e3..9b58ac9449761191785c6b309b42322a6075b97d 100644
--- a/llm/oai/oai_responses.go
+++ b/llm/oai/oai_responses.go
@@ -340,6 +340,8 @@ func (s *ResponsesService) TokenContextWindow() int {
// Use the same context window logic as the regular service
switch model.ModelName {
+ case "gpt-5.2-codex":
+ return 272000 // 272k for gpt-5.2-codex
case "gpt-5.1-codex":
return 256000 // 256k for gpt-5.1-codex
case "gpt-4.1-2025-04-14", "gpt-4.1-mini-2025-04-14", "gpt-4.1-nano-2025-04-14":
diff --git a/llm/oai/oai_responses_test.go b/llm/oai/oai_responses_test.go
index 8e696cee30f953ccf9d34060f96b9028cae59795..d47492b11410004c9da94c0a6c1e62ce001e6693 100644
--- a/llm/oai/oai_responses_test.go
+++ b/llm/oai/oai_responses_test.go
@@ -284,6 +284,7 @@ func TestResponsesServiceTokenContextWindow(t *testing.T) {
model Model
expected int
}{
+ {model: GPT52Codex, expected: 272000},
{model: GPT5Codex, expected: 256000},
{model: GPT41, expected: 200000},
{model: GPT4o, expected: 128000},
diff --git a/llm/oai/oai_test.go b/llm/oai/oai_test.go
index e631e061b9a600856862ca57433dd24126275f84..76e0648be6d69314aea4e6e9d5b34b76473f2a1a 100644
--- a/llm/oai/oai_test.go
+++ b/llm/oai/oai_test.go
@@ -12,106 +12,6 @@ import (
"shelley.exe.dev/llm"
)
-func TestRequiresMaxCompletionTokens(t *testing.T) {
- tests := []struct {
- name string
- model Model
- expected bool
- }{
- {
- name: "GPT-5 requires max_completion_tokens",
- model: GPT5,
- expected: true,
- },
- {
- name: "GPT-5 Mini requires max_completion_tokens",
- model: GPT5Mini,
- expected: true,
- },
- {
- name: "O3 reasoning model requires max_completion_tokens",
- model: O3,
- expected: true,
- },
- {
- name: "O4-mini reasoning model requires max_completion_tokens",
- model: O4Mini,
- expected: true,
- },
- {
- name: "GPT-4.1 uses max_tokens",
- model: GPT41,
- expected: false,
- },
- {
- name: "GPT-4o uses max_tokens",
- model: GPT4o,
- expected: false,
- },
- {
- name: "GPT-4o Mini uses max_tokens",
- model: GPT4oMini,
- expected: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := tt.model.requiresMaxCompletionTokens()
- if result != tt.expected {
- t.Errorf("requiresMaxCompletionTokens() = %v, expected %v", result, tt.expected)
- }
- })
- }
-}
-
-func TestRequestParameterGeneration(t *testing.T) {
- // Test that we can generate the correct request structure without making API calls
- tests := []struct {
- name string
- model Model
- expectMaxTokens bool
- expectMaxCompletionTokens bool
- }{
- {
- name: "GPT-5 uses max_completion_tokens",
- model: GPT5,
- expectMaxTokens: false,
- expectMaxCompletionTokens: true,
- },
- {
- name: "GPT-5 Mini uses max_completion_tokens",
- model: GPT5Mini,
- expectMaxTokens: false,
- expectMaxCompletionTokens: true,
- },
- {
- name: "GPT-4.1 uses max_tokens",
- model: GPT41,
- expectMaxTokens: true,
- expectMaxCompletionTokens: false,
- },
- {
- name: "O3 uses max_completion_tokens",
- model: O3,
- expectMaxTokens: false,
- expectMaxCompletionTokens: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- usesMaxCompletionTokens := tt.model.requiresMaxCompletionTokens()
- if tt.expectMaxCompletionTokens && !usesMaxCompletionTokens {
- t.Errorf("Expected model %s to use max_completion_tokens, but it doesn't", tt.model.UserName)
- }
- if tt.expectMaxTokens && usesMaxCompletionTokens {
- t.Errorf("Expected model %s to use max_tokens, but it uses max_completion_tokens", tt.model.UserName)
- }
- })
- }
-}
-
func TestToRoleFromString(t *testing.T) {
tests := []struct {
name string
diff --git a/models/models.go b/models/models.go
index 9bf92e18e81b20af00432f5b8bb1acfcb2d86dbc..503cbab9475af9ae407f4ef1e5556f2a03257d95 100644
--- a/models/models.go
+++ b/models/models.go
@@ -148,47 +148,15 @@ func All() []Model {
},
},
{
- ID: "gpt-5",
+ ID: "gpt-5.2-codex",
Provider: ProviderOpenAI,
- Description: "GPT-5",
+ Description: "GPT-5.2 Codex",
RequiredEnvVars: []string{"OPENAI_API_KEY"},
Factory: func(config *Config, httpc *http.Client) (llm.Service, error) {
if config.OpenAIAPIKey == "" {
- return nil, fmt.Errorf("gpt-5 requires OPENAI_API_KEY")
+ return nil, fmt.Errorf("gpt-5.2-codex requires OPENAI_API_KEY")
}
- svc := &oai.Service{Model: oai.GPT5, APIKey: config.OpenAIAPIKey, HTTPC: httpc}
- if url := config.getOpenAIURL(); url != "" {
- svc.ModelURL = url
- }
- return svc, nil
- },
- },
- {
- ID: "gpt-5-nano",
- Provider: ProviderOpenAI,
- Description: "GPT-5 Nano",
- RequiredEnvVars: []string{"OPENAI_API_KEY"},
- Factory: func(config *Config, httpc *http.Client) (llm.Service, error) {
- if config.OpenAIAPIKey == "" {
- return nil, fmt.Errorf("gpt-5-nano requires OPENAI_API_KEY")
- }
- svc := &oai.Service{Model: oai.GPT5Nano, APIKey: config.OpenAIAPIKey, HTTPC: httpc}
- if url := config.getOpenAIURL(); url != "" {
- svc.ModelURL = url
- }
- return svc, nil
- },
- },
- {
- ID: "gpt-5.1-codex",
- Provider: ProviderOpenAI,
- Description: "GPT-5.1 Codex (uses Responses API)",
- RequiredEnvVars: []string{"OPENAI_API_KEY"},
- Factory: func(config *Config, httpc *http.Client) (llm.Service, error) {
- if config.OpenAIAPIKey == "" {
- return nil, fmt.Errorf("gpt-5.1-codex requires OPENAI_API_KEY")
- }
- svc := &oai.ResponsesService{Model: oai.GPT5Codex, APIKey: config.OpenAIAPIKey, HTTPC: httpc}
+ svc := &oai.ResponsesService{Model: oai.GPT52Codex, APIKey: config.OpenAIAPIKey, HTTPC: httpc}
if url := config.getOpenAIURL(); url != "" {
svc.ModelURL = url
}
@@ -259,38 +227,6 @@ func All() []Model {
return svc, nil
},
},
- {
- ID: "gemini-2.5-pro",
- Provider: ProviderGemini,
- Description: "Gemini 2.5 Pro",
- RequiredEnvVars: []string{"GEMINI_API_KEY"},
- Factory: func(config *Config, httpc *http.Client) (llm.Service, error) {
- if config.GeminiAPIKey == "" {
- return nil, fmt.Errorf("gemini-2.5-pro requires GEMINI_API_KEY")
- }
- svc := &gem.Service{APIKey: config.GeminiAPIKey, Model: "gemini-2.5-pro", HTTPC: httpc}
- if url := config.getGeminiURL(); url != "" {
- svc.URL = url
- }
- return svc, nil
- },
- },
- {
- ID: "gemini-2.5-flash",
- Provider: ProviderGemini,
- Description: "Gemini 2.5 Flash",
- RequiredEnvVars: []string{"GEMINI_API_KEY"},
- Factory: func(config *Config, httpc *http.Client) (llm.Service, error) {
- if config.GeminiAPIKey == "" {
- return nil, fmt.Errorf("gemini-2.5-flash requires GEMINI_API_KEY")
- }
- svc := &gem.Service{APIKey: config.GeminiAPIKey, Model: "gemini-2.5-flash", HTTPC: httpc}
- if url := config.getGeminiURL(); url != "" {
- svc.URL = url
- }
- return svc, nil
- },
- },
{
ID: "predictable",
Provider: ProviderBuiltIn,
@@ -332,7 +268,8 @@ func Default() Model {
type Manager struct {
services map[string]serviceEntry
logger *slog.Logger
- db *db.DB
+ db *db.DB // for custom models and LLM request recording
+ httpc *http.Client // HTTP client with recording middleware
}
type serviceEntry struct {
@@ -502,6 +439,9 @@ func NewManager(cfg *Config) (*Manager, error) {
httpc = llmhttp.NewClient(nil, nil)
}
+ // Store the HTTP client for use with custom models
+ manager.httpc = httpc
+
for _, model := range All() {
svc, err := model.Factory(cfg, httpc)
if err != nil {
@@ -520,6 +460,34 @@ func NewManager(cfg *Config) (*Manager, error) {
// GetService returns the LLM service for the given model ID, wrapped with logging
func (m *Manager) GetService(modelID string) (llm.Service, error) {
+ // Check custom models first if we have a database
+ if m.db != nil {
+ dbModels, err := m.db.GetModels(context.Background())
+ if err == nil && len(dbModels) > 0 {
+ // Custom models exist - only serve custom models, not built-in ones
+ for _, model := range dbModels {
+ if model.ModelID == modelID {
+ svc := m.createServiceFromModel(&model)
+ if svc != nil {
+ if m.logger != nil {
+ return &loggingService{
+ service: svc,
+ logger: m.logger,
+ modelID: modelID,
+ provider: Provider(model.ProviderType),
+ db: m.db,
+ }, nil
+ }
+ return svc, nil
+ }
+ }
+ }
+ // Custom models exist but this model ID wasn't found among them
+ return nil, fmt.Errorf("unsupported model: %s", modelID)
+ }
+ }
+
+ // No custom models - fall back to built-in models
if entry, ok := m.services[modelID]; ok {
// Wrap with logging if we have a logger
if m.logger != nil {
@@ -538,9 +506,20 @@ func (m *Manager) GetService(modelID string) (llm.Service, error) {
// GetAvailableModels returns a list of available model IDs in the same order as All()
func (m *Manager) GetAvailableModels() []string {
- // Return IDs in the same order as All() for consistency
- all := All()
var ids []string
+
+ // If we have custom models in the database, use ONLY those
+ if m.db != nil {
+ if dbModels, err := m.db.GetModels(context.Background()); err == nil && len(dbModels) > 0 {
+ for _, model := range dbModels {
+ ids = append(ids, model.ModelID)
+ }
+ return ids
+ }
+ }
+
+ // No custom models - fall back to built-in models in the same order as All()
+ all := All()
for _, model := range all {
if _, ok := m.services[model.ID]; ok {
ids = append(ids, model.ID)
@@ -551,6 +530,80 @@ func (m *Manager) GetAvailableModels() []string {
// HasModel reports whether the manager has a service for the given model ID
func (m *Manager) HasModel(modelID string) bool {
+ // Check custom models first
+ if m.db != nil {
+ if model, err := m.db.GetModel(context.Background(), modelID); err == nil && model != nil {
+ return true
+ }
+ }
_, ok := m.services[modelID]
return ok
}
+
+// ModelInfo contains display name and tags for a model
+type ModelInfo struct {
+ DisplayName string
+ Tags string
+}
+
+// GetModelInfo returns the display name and tags for a model
+func (m *Manager) GetModelInfo(modelID string) *ModelInfo {
+ if m.db == nil {
+ return nil
+ }
+ model, err := m.db.GetModel(context.Background(), modelID)
+ if err != nil {
+ return nil
+ }
+ return &ModelInfo{
+ DisplayName: model.DisplayName,
+ Tags: model.Tags,
+ }
+}
+
+// createServiceFromModel creates an LLM service from a database model configuration
+func (m *Manager) createServiceFromModel(model *generated.Model) llm.Service {
+ switch model.ProviderType {
+ case "anthropic":
+ return &ant.Service{
+ APIKey: model.ApiKey,
+ URL: model.Endpoint,
+ Model: model.ModelName,
+ HTTPC: m.httpc,
+ }
+ case "openai":
+ return &oai.Service{
+ APIKey: model.ApiKey,
+ ModelURL: model.Endpoint,
+ Model: oai.Model{
+ ModelName: model.ModelName,
+ URL: model.Endpoint,
+ },
+ MaxTokens: int(model.MaxTokens),
+ HTTPC: m.httpc,
+ }
+ case "openai-responses":
+ return &oai.ResponsesService{
+ APIKey: model.ApiKey,
+ ModelURL: model.Endpoint,
+ Model: oai.Model{
+ ModelName: model.ModelName,
+ URL: model.Endpoint,
+ },
+ MaxTokens: int(model.MaxTokens),
+ HTTPC: m.httpc,
+ }
+ case "gemini":
+ return &gem.Service{
+ APIKey: model.ApiKey,
+ URL: model.Endpoint,
+ Model: model.ModelName,
+ HTTPC: m.httpc,
+ }
+ default:
+ if m.logger != nil {
+ m.logger.Error("Unknown provider type for model", "model_id", model.ModelID, "provider_type", model.ProviderType)
+ }
+ return nil
+ }
+}
diff --git a/models/models_test.go b/models/models_test.go
index fd0c792c69107ecd6536fc6edabcd3867fe3664a..5baf3b2f5ccea33444d7d62f2a1dc4ed802b2831 100644
--- a/models/models_test.go
+++ b/models/models_test.go
@@ -36,7 +36,7 @@ func TestByID(t *testing.T) {
wantNil bool
}{
{id: "qwen3-coder-fireworks", wantID: "qwen3-coder-fireworks", wantNil: false},
- {id: "gpt-5", wantID: "gpt-5", wantNil: false},
+ {id: "gpt-5.2-codex", wantID: "gpt-5.2-codex", wantNil: false},
{id: "claude-sonnet-4.5", wantID: "claude-sonnet-4.5", wantNil: false},
{id: "claude-haiku-4.5", wantID: "claude-haiku-4.5", wantNil: false},
{id: "claude-opus-4.5", wantID: "claude-opus-4.5", wantNil: false},
diff --git a/server/cancel_claude_test.go b/server/cancel_claude_test.go
index c6d9344388c0fdd0679cc16f5ef1115508e68e7c..4736e10d8fd72190c5c06882b34ef42599c6a2f2 100644
--- a/server/cancel_claude_test.go
+++ b/server/cancel_claude_test.go
@@ -19,6 +19,7 @@ import (
"shelley.exe.dev/db/generated"
"shelley.exe.dev/llm"
"shelley.exe.dev/llm/ant"
+ "shelley.exe.dev/models"
)
// ClaudeTestHarness extends TestHarness with Claude-specific functionality
@@ -544,6 +545,13 @@ func (m *claudeLLMManager) HasModel(modelID string) bool {
return modelID == "claude" || modelID == "claude-haiku-4.5"
}
+func (m *claudeLLMManager) GetModelInfo(modelID string) *models.ModelInfo {
+ if modelID == "claude-haiku-4.5" {
+ return &models.ModelInfo{DisplayName: "Claude Haiku", Tags: "slug"}
+ }
+ return nil
+}
+
// TestClaudeCancelDuringToolCall tests cancellation during tool execution with Claude
func TestClaudeCancelDuringToolCall(t *testing.T) {
h := NewClaudeTestHarness(t)
diff --git a/server/cancel_test.go b/server/cancel_test.go
index 73fda6b90c05e8d3ef13c98b565d2804a2e91603..6917229091fe93cfb1e4a428f3ce69f710e5de5a 100644
--- a/server/cancel_test.go
+++ b/server/cancel_test.go
@@ -15,6 +15,7 @@ import (
"shelley.exe.dev/db/generated"
"shelley.exe.dev/llm"
"shelley.exe.dev/loop"
+ "shelley.exe.dev/models"
)
// setupTestDB creates a test database
@@ -374,3 +375,7 @@ func (m *testLLMManager) GetAvailableModels() []string {
func (m *testLLMManager) HasModel(modelID string) bool {
return modelID == "predictable"
}
+
+func (m *testLLMManager) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
diff --git a/server/custom_models.go b/server/custom_models.go
new file mode 100644
index 0000000000000000000000000000000000000000..03116c6b9fbf2226462a7ce36949ba57cfe93c50
--- /dev/null
+++ b/server/custom_models.go
@@ -0,0 +1,404 @@
+package server
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+ "shelley.exe.dev/db/generated"
+ "shelley.exe.dev/llm"
+ "shelley.exe.dev/llm/ant"
+ "shelley.exe.dev/llm/gem"
+ "shelley.exe.dev/llm/oai"
+)
+
+// ModelAPI is the API representation of a model
+type ModelAPI struct {
+ ModelID string `json:"model_id"`
+ DisplayName string `json:"display_name"`
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ APIKey string `json:"api_key"`
+ ModelName string `json:"model_name"`
+ MaxTokens int64 `json:"max_tokens"`
+ Tags string `json:"tags"` // Comma-separated tags (e.g., "slug" for slug generation)
+}
+
+// CreateModelRequest is the request body for creating a model
+type CreateModelRequest struct {
+ DisplayName string `json:"display_name"`
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ APIKey string `json:"api_key"`
+ ModelName string `json:"model_name"`
+ MaxTokens int64 `json:"max_tokens"`
+ Tags string `json:"tags"` // Comma-separated tags
+}
+
+// UpdateModelRequest is the request body for updating a model
+type UpdateModelRequest struct {
+ DisplayName string `json:"display_name"`
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ APIKey string `json:"api_key"` // Empty string means keep existing
+ ModelName string `json:"model_name"`
+ MaxTokens int64 `json:"max_tokens"`
+ Tags string `json:"tags"` // Comma-separated tags
+}
+
+// TestModelRequest is the request body for testing a model
+type TestModelRequest struct {
+ ModelID string `json:"model_id,omitempty"` // If provided, use stored API key
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ APIKey string `json:"api_key"`
+ ModelName string `json:"model_name"`
+}
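+
+// Example test call (illustrative values; port 8002 as in AGENTS.md):
+//
+//	curl -X POST http://localhost:8002/api/custom-models-test \
+//	  -H 'Content-Type: application/json' \
+//	  -d '{"provider_type":"anthropic","endpoint":"https://api.anthropic.com","api_key":"sk-ant-...","model_name":"claude-sonnet-4-5"}'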
+
+func toModelAPI(m generated.Model) ModelAPI {
+ return ModelAPI{
+ ModelID: m.ModelID,
+ DisplayName: m.DisplayName,
+ ProviderType: m.ProviderType,
+ Endpoint: m.Endpoint,
+ APIKey: m.ApiKey,
+ ModelName: m.ModelName,
+ MaxTokens: m.MaxTokens,
+ Tags: m.Tags,
+ }
+}
+
+func (s *Server) handleCustomModels(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ s.handleListModels(w, r)
+ case http.MethodPost:
+ s.handleCreateModel(w, r)
+ default:
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+}
+
+func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
+ models, err := s.db.GetModels(r.Context())
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to get models: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ apiModels := make([]ModelAPI, len(models))
+ for i, m := range models {
+ apiModels[i] = toModelAPI(m)
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(apiModels)
+}
+
+func (s *Server) handleCreateModel(w http.ResponseWriter, r *http.Request) {
+ var req CreateModelRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
+ return
+ }
+
+ // Validate required fields
+ if req.DisplayName == "" || req.ProviderType == "" || req.Endpoint == "" || req.APIKey == "" || req.ModelName == "" {
+ http.Error(w, "display_name, provider_type, endpoint, api_key, and model_name are required", http.StatusBadRequest)
+ return
+ }
+
+ // Validate provider type
+ if req.ProviderType != "anthropic" && req.ProviderType != "openai" && req.ProviderType != "openai-responses" && req.ProviderType != "gemini" {
+ http.Error(w, "provider_type must be 'anthropic', 'openai', 'openai-responses', or 'gemini'", http.StatusBadRequest)
+ return
+ }
+
+ // Generate model ID
+ modelID := "custom-" + uuid.New().String()[:8]
+
+ // Default max tokens
+ if req.MaxTokens <= 0 {
+ req.MaxTokens = 200000
+ }
+
+ model, err := s.db.CreateModel(r.Context(), generated.CreateModelParams{
+ ModelID: modelID,
+ DisplayName: req.DisplayName,
+ ProviderType: req.ProviderType,
+ Endpoint: req.Endpoint,
+ ApiKey: req.APIKey,
+ ModelName: req.ModelName,
+ MaxTokens: req.MaxTokens,
+ Tags: req.Tags,
+ })
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to create model: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusCreated)
+ json.NewEncoder(w).Encode(toModelAPI(*model))
+}
+
+func (s *Server) handleCustomModel(w http.ResponseWriter, r *http.Request) {
+ // Extract model ID from URL path: /api/custom-models/{id} or /api/custom-models/{id}/duplicate
+ path := strings.TrimPrefix(r.URL.Path, "/api/custom-models/")
+ if path == "" {
+ http.Error(w, "Invalid model ID", http.StatusBadRequest)
+ return
+ }
+
+ // Check for /duplicate suffix
+ if strings.HasSuffix(path, "/duplicate") {
+ modelID := strings.TrimSuffix(path, "/duplicate")
+ if r.Method == http.MethodPost {
+ s.handleDuplicateModel(w, r, modelID)
+ } else {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+ return
+ }
+
+ if strings.Contains(path, "/") {
+ http.Error(w, "Invalid model ID", http.StatusBadRequest)
+ return
+ }
+ modelID := path
+
+ switch r.Method {
+ case http.MethodGet:
+ s.handleGetModel(w, r, modelID)
+ case http.MethodPut:
+ s.handleUpdateModel(w, r, modelID)
+ case http.MethodDelete:
+ s.handleDeleteModel(w, r, modelID)
+ default:
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+}
+
+func (s *Server) handleGetModel(w http.ResponseWriter, r *http.Request, modelID string) {
+ model, err := s.db.GetModel(r.Context(), modelID)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to get model: %v", err), http.StatusNotFound)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(toModelAPI(*model))
+}
+
+func (s *Server) handleUpdateModel(w http.ResponseWriter, r *http.Request, modelID string) {
+ // First, get the existing model to get the current API key if not provided
+ existing, err := s.db.GetModel(r.Context(), modelID)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Model not found: %v", err), http.StatusNotFound)
+ return
+ }
+
+ var req UpdateModelRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
+ return
+ }
+
+ // Use existing API key if not provided
+ apiKey := req.APIKey
+ if apiKey == "" {
+ apiKey = existing.ApiKey
+ }
+
+ // Default max tokens
+ if req.MaxTokens <= 0 {
+ req.MaxTokens = 200000
+ }
+
+ model, err := s.db.UpdateModel(r.Context(), generated.UpdateModelParams{
+ DisplayName: req.DisplayName,
+ ProviderType: req.ProviderType,
+ Endpoint: req.Endpoint,
+ ApiKey: apiKey,
+ ModelName: req.ModelName,
+ MaxTokens: req.MaxTokens,
+ Tags: req.Tags,
+ ModelID: modelID,
+ })
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to update model: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(toModelAPI(*model))
+}
+
+func (s *Server) handleDeleteModel(w http.ResponseWriter, r *http.Request, modelID string) {
+ err := s.db.DeleteModel(r.Context(), modelID)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to delete model: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ w.WriteHeader(http.StatusNoContent)
+}
+
+// DuplicateModelRequest allows overriding fields when duplicating
+type DuplicateModelRequest struct {
+ DisplayName string `json:"display_name,omitempty"`
+}
+
+func (s *Server) handleDuplicateModel(w http.ResponseWriter, r *http.Request, modelID string) {
+ // Get the source model (including API key)
+ source, err := s.db.GetModel(r.Context(), modelID)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Source model not found: %v", err), http.StatusNotFound)
+ return
+ }
+
+ // Parse optional overrides
+ var req DuplicateModelRequest
+ if r.Body != nil {
+ json.NewDecoder(r.Body).Decode(&req) // Ignore errors - all fields optional
+ }
+
+ // Generate new model ID
+ newModelID := "custom-" + uuid.New().String()[:8]
+
+ // Use provided display name or generate one
+ displayName := req.DisplayName
+ if displayName == "" {
+ displayName = source.DisplayName + " (copy)"
+ }
+
+ // Create the duplicate with the same API key
+ model, err := s.db.CreateModel(r.Context(), generated.CreateModelParams{
+ ModelID: newModelID,
+ DisplayName: displayName,
+ ProviderType: source.ProviderType,
+ Endpoint: source.Endpoint,
+ ApiKey: source.ApiKey, // Copy the API key!
+ ModelName: source.ModelName,
+ MaxTokens: source.MaxTokens,
+ Tags: "", // Don't copy tags
+ })
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to duplicate model: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusCreated)
+ json.NewEncoder(w).Encode(toModelAPI(*model))
+}
+
+func (s *Server) handleTestModel(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ var req TestModelRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
+ return
+ }
+
+ // If model_id is provided and api_key is empty, look up the stored key
+ if req.ModelID != "" && req.APIKey == "" {
+ model, err := s.db.GetModel(r.Context(), req.ModelID)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Model not found: %v", err), http.StatusNotFound)
+ return
+ }
+ req.APIKey = model.ApiKey
+ }
+
+ if req.ProviderType == "" || req.Endpoint == "" || req.APIKey == "" || req.ModelName == "" {
+ http.Error(w, "provider_type, endpoint, api_key, and model_name are required", http.StatusBadRequest)
+ return
+ }
+
+ // Create the appropriate service based on provider type
+ var service llm.Service
+ switch req.ProviderType {
+ case "anthropic":
+ service = &ant.Service{
+ APIKey: req.APIKey,
+ URL: req.Endpoint,
+ Model: req.ModelName,
+ }
+ case "openai":
+ service = &oai.Service{
+ APIKey: req.APIKey,
+ ModelURL: req.Endpoint,
+ Model: oai.Model{
+ ModelName: req.ModelName,
+ URL: req.Endpoint,
+ },
+ }
+ case "gemini":
+ service = &gem.Service{
+ APIKey: req.APIKey,
+ URL: req.Endpoint,
+ Model: req.ModelName,
+ }
+ case "openai-responses":
+ service = &oai.ResponsesService{
+ APIKey: req.APIKey,
+ Model: oai.Model{
+ ModelName: req.ModelName,
+ URL: req.Endpoint,
+ },
+ }
+ default:
+ http.Error(w, "Invalid provider_type", http.StatusBadRequest)
+ return
+ }
+
+ // Send a simple test request
+ ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
+ defer cancel()
+
+ request := &llm.Request{
+ Messages: []llm.Message{
+ {
+ Role: llm.MessageRoleUser,
+ Content: []llm.Content{
+ {Type: llm.ContentTypeText, Text: "Say 'test successful' in exactly two words."},
+ },
+ },
+ },
+ }
+
+ response, err := service.Do(ctx, request)
+ if err != nil {
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]interface{}{
+ "success": false,
+ "message": fmt.Sprintf("Test failed: %v", err),
+ })
+ return
+ }
+
+ // Check if we got a response
+ if len(response.Content) == 0 || response.Content[0].Text == "" {
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]interface{}{
+ "success": false,
+ "message": "Test failed: empty response from model",
+ })
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]interface{}{
+ "success": true,
+ "message": fmt.Sprintf("Test successful! Response: %s", response.Content[0].Text),
+ })
+}
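+
+// Example responses from handleTestModel (shapes match the encodes above; message text is illustrative):
+//
+//	{"success":true,"message":"Test successful! Response: test successful"}
+//	{"success":false,"message":"Test failed: <error from provider>"}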
diff --git a/server/debug_handlers.go b/server/debug_handlers.go
index 3ff6e156cf81651209d1a66e92f2d3b775b613e4..00dceea25f620bc0b273804fd281af5548d387b6 100644
--- a/server/debug_handlers.go
+++ b/server/debug_handlers.go
@@ -182,6 +182,8 @@ tr:hover { background: #252525; }
}
.tab-content { display: none; }
.tab-content.active { display: block; }
+.model-display { color: #a5d6ff; }
+.model-id { color: #888; font-size: 11px; }
@@ -228,6 +230,13 @@ function formatDuration(ms) {
return (ms / 1000).toFixed(2) + 's';
}
+function formatModel(model, displayName) {
+ if (displayName) {
+ return '<span class="model-display">' + displayName + '</span> <span class="model-id">(' + model + ')</span>';
+ }
+ return model;
+}
+
function syntaxHighlight(json) {
if (typeof json !== 'string') json = JSON.stringify(json, null, 2);
json = json.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;');
@@ -254,7 +263,7 @@ async function loadRequests() {
const data = await resp.json();
renderTable(data);
} catch (e) {
- document.getElementById('requests-body').innerHTML =
+ document.getElementById('requests-body').innerHTML =
'<tr><td>Error loading requests: ' + e.message + '</td></tr>';
}
}
@@ -269,20 +278,20 @@ function renderTable(requests) {
for (const req of requests) {
const tr = document.createElement('tr');
tr.id = 'row-' + req.id;
-
- const statusClass = req.status_code && req.status_code >= 200 && req.status_code < 300 ? 'success' :
+
+ const statusClass = req.status_code && req.status_code >= 200 && req.status_code < 300 ? 'success' :
(req.status_code ? 'error' : '');
-
+
let prefixInfo = '-';
if (req.prefix_request_id) {
- prefixInfo = 'prefix from #' + req.prefix_request_id +
+ prefixInfo = 'prefix from #' + req.prefix_request_id +
' (' + formatSize(req.prefix_length) + ')';
}
-
+
tr.innerHTML = ` + "`" + `
<td>${req.id}</td>
<td>${formatDate(req.created_at)}</td>
- <td>${req.model}</td>
+ <td>${formatModel(req.model, req.model_display_name)}</td>
<td>${req.provider}</td>
<td class="${statusClass}">${req.status_code || '-'}${req.error ? ' ⚠' : ''}</td>
<td>${formatDuration(req.duration_ms)}</td>
@@ -302,7 +311,7 @@ async function toggleExpand(id) {
expandedRows.delete(id);
return;
}
-
+
expandedRows.add(id);
const row = document.getElementById('row-' + id);
const expandRow = document.createElement('tr');
@@ -325,7 +334,7 @@ async function toggleExpand(id) {
` + "`" + `;
row.after(expandRow);
-
+
// Load request body
loadBody(id, 'request');
}
@@ -336,9 +345,9 @@ async function loadBody(id, type) {
renderBody(id, type, loadedData[key]);
return;
}
-
+
try {
- const url = type === 'request'
+ const url = type === 'request'
? '/debug/llm_requests/' + id + '/request'
: '/debug/llm_requests/' + id + '/response';
const resp = await fetch(url);
@@ -363,13 +372,13 @@ async function loadBody(id, type) {
function renderBody(id, type, data) {
const container = document.querySelector('#tab-' + type + '-' + id + ' pre');
if (!container) return;
-
+
if (data === null) {
container.className = '';
container.textContent = '(empty)';
return;
}
-
+
container.className = '';
if (typeof data === 'object') {
container.innerHTML = syntaxHighlight(JSON.stringify(data, null, 2));
@@ -382,14 +391,14 @@ function showTab(id, tab) {
// Update tab buttons
const expandRow = document.getElementById('expand-' + id);
if (!expandRow) return;
-
+
expandRow.querySelectorAll('.tab-btn').forEach(btn => {
btn.classList.remove('active');
if (btn.textContent.toLowerCase() === tab) {
btn.classList.add('active');
}
});
-
+
// Update tab content
expandRow.querySelectorAll('.tab-content').forEach(content => {
content.classList.remove('active');
diff --git a/server/handlers.go b/server/handlers.go
index 2d6cd14281f8bb6df4a677624dd184ea460e1bb8..b6cfa46e03137f5e043a3832983d4499cddaf2a9 100644
--- a/server/handlers.go
+++ b/server/handlers.go
@@ -357,30 +357,7 @@ func (s *Server) serveIndexWithInit(w http.ResponseWriter, r *http.Request, fs h
}
// Build initialization data
- type ModelInfo struct {
- ID string `json:"id"`
- Ready bool `json:"ready"`
- MaxContextTokens int `json:"max_context_tokens,omitempty"`
- }
-
- var modelList []ModelInfo
- if s.predictableOnly {
- modelList = append(modelList, ModelInfo{ID: "predictable", Ready: true, MaxContextTokens: 200000})
- } else {
- modelIDs := s.llmManager.GetAvailableModels()
- for _, id := range modelIDs {
- // Skip predictable model unless predictable-only flag is set
- if id == "predictable" {
- continue
- }
- svc, err := s.llmManager.GetService(id)
- maxCtx := 0
- if err == nil && svc != nil {
- maxCtx = svc.TokenContextWindow()
- }
- modelList = append(modelList, ModelInfo{ID: id, Ready: err == nil, MaxContextTokens: maxCtx})
- }
- }
+ modelList := s.getModelList()
// Select default model - use configured default if available, otherwise first ready model
defaultModel := s.defaultModel
@@ -913,6 +890,52 @@ func (s *Server) handleVersion(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(version.GetInfo())
}
+// ModelInfo represents a model in the API response
+type ModelInfo struct {
+ ID string `json:"id"`
+ DisplayName string `json:"display_name,omitempty"`
+ Ready bool `json:"ready"`
+ MaxContextTokens int `json:"max_context_tokens,omitempty"`
+}
+
+// getModelList returns the list of available models
+func (s *Server) getModelList() []ModelInfo {
+ var modelList []ModelInfo
+ if s.predictableOnly {
+ modelList = append(modelList, ModelInfo{ID: "predictable", Ready: true, MaxContextTokens: 200000})
+ } else {
+ modelIDs := s.llmManager.GetAvailableModels()
+ for _, id := range modelIDs {
+ // Skip predictable model unless predictable-only flag is set
+ if id == "predictable" {
+ continue
+ }
+ svc, err := s.llmManager.GetService(id)
+ maxCtx := 0
+ if err == nil && svc != nil {
+ maxCtx = svc.TokenContextWindow()
+ }
+ info := ModelInfo{ID: id, Ready: err == nil, MaxContextTokens: maxCtx}
+ // Add display name from model info
+ if modelInfo := s.llmManager.GetModelInfo(id); modelInfo != nil {
+ info.DisplayName = modelInfo.DisplayName
+ }
+ modelList = append(modelList, info)
+ }
+ }
+ return modelList
+}
+
+// handleModels returns the list of available models
+func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(s.getModelList())
+}
+
// handleArchivedConversations handles GET /api/conversations/archived
func (s *Server) handleArchivedConversations(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
diff --git a/server/server.go b/server/server.go
index d2fe39610827e4baa65ea103e4591781aafba40b..e96f34fb95651f1dc7ef7951c29a6a06817df167 100644
--- a/server/server.go
+++ b/server/server.go
@@ -69,6 +69,7 @@ type LLMProvider interface {
GetService(modelID string) (llm.Service, error)
GetAvailableModels() []string
HasModel(modelID string) bool
+ GetModelInfo(modelID string) *models.ModelInfo
}
// NewLLMServiceManager creates a new LLM service manager from config
@@ -261,6 +262,14 @@ func (s *Server) RegisterRoutes(mux *http.ServeMux) {
mux.HandleFunc("/api/read", s.handleRead) // Serves images
mux.Handle("/api/write-file", http.HandlerFunc(s.handleWriteFile)) // Small response
+ // Custom models API
+ mux.Handle("/api/custom-models", http.HandlerFunc(s.handleCustomModels))
+ mux.Handle("/api/custom-models/", http.HandlerFunc(s.handleCustomModel))
+ mux.Handle("/api/custom-models-test", http.HandlerFunc(s.handleTestModel))
+
+ // Models API (dynamic list refresh)
+ mux.Handle("/api/models", http.HandlerFunc(s.handleModels))
+
// Version endpoints
mux.Handle("GET /version", http.HandlerFunc(s.handleVersion))
mux.Handle("GET /version-check", http.HandlerFunc(s.handleVersionCheck))
diff --git a/slug/slug.go b/slug/slug.go
index c3258786113010c4d0b302b5391987499ac42f4a..87b27d879e3e4025a00449a5b922604f4c9ed298 100644
--- a/slug/slug.go
+++ b/slug/slug.go
@@ -10,15 +10,18 @@ import (
"shelley.exe.dev/db"
"shelley.exe.dev/llm"
+ "shelley.exe.dev/models"
)
// LLMServiceProvider defines the interface for getting LLM services
type LLMServiceProvider interface {
GetService(modelID string) (llm.Service, error)
+ GetAvailableModels() []string
+ GetModelInfo(modelID string) *models.ModelInfo
}
// GenerateSlug generates a slug for a conversation and updates the database
-// If conversationModelID is provided, it will try to use that model first before falling back to the default list
+// If conversationModelID is provided, it will be used as a fallback if no model is tagged with "slug"
func GenerateSlug(ctx context.Context, llmProvider LLMServiceProvider, database *db.DB, logger *slog.Logger, conversationID, userMessage, conversationModelID string) (string, error) {
baseSlug, err := generateSlugText(ctx, llmProvider, logger, userMessage, conversationModelID)
if err != nil {
@@ -53,15 +56,14 @@ func GenerateSlug(ctx context.Context, llmProvider LLMServiceProvider, database
}
// generateSlugText generates a human-readable slug for a conversation based on the user message
-// If conversationModelID is "predictable", it will be used instead of the default preferred models
+// Priority order:
+// 1. If conversationModelID is "predictable", use it
+// 2. Try models tagged with "slug"
+// 3. Fall back to the conversation's model (conversationModelID)
func generateSlugText(ctx context.Context, llmProvider LLMServiceProvider, logger *slog.Logger, userMessage, conversationModelID string) (string, error) {
- // Try different models in order of preference
var llmService llm.Service
var err error
- // Preferred models in order of preference
- preferredModels := []string{"qwen3-coder-fireworks", "gpt5-mini", "gpt-5-thinking-mini", "claude-sonnet-4.5", "predictable"}
-
// If conversation is using predictable model, use it for slug generation too
if conversationModelID == "predictable" {
llmService, err = llmProvider.GetService("predictable")
@@ -72,15 +74,29 @@ func generateSlugText(ctx context.Context, llmProvider LLMServiceProvider, logge
}
}
- // If we didn't get the predictable service, try the preferred models
+ // Try models tagged with "slug"
if llmService == nil {
- for _, model := range preferredModels {
- llmService, err = llmProvider.GetService(model)
- if err == nil {
- logger.Debug("Using preferred model for slug generation", "model", model)
+ for _, modelID := range llmProvider.GetAvailableModels() {
+ info := llmProvider.GetModelInfo(modelID)
+ if info != nil && strings.Contains(info.Tags, "slug") {
+ llmService, err = llmProvider.GetService(modelID)
+ if err == nil {
+ logger.Debug("Using slug-tagged model for slug generation", "model", modelID)
+ } else {
+ logger.Debug("Failed to get slug-tagged model", "model", modelID, "error", err)
+ }
break
}
- logger.Debug("Model not available for slug generation", "model", model, "error", err)
+ }
+ }
+
+ // Fall back to the conversation's model
+ if llmService == nil && conversationModelID != "" && conversationModelID != "predictable" {
+ llmService, err = llmProvider.GetService(conversationModelID)
+ if err == nil {
+ logger.Debug("Using conversation model for slug generation", "model", conversationModelID)
+ } else {
+ logger.Debug("Conversation model not available for slug generation", "model", conversationModelID, "error", err)
}
}
@@ -134,9 +150,6 @@ Respond with only the slug, nothing else.`, userMessage)
return "", fmt.Errorf("generated slug is empty after sanitization")
}
- // Note: We don't check for uniqueness here since we're generating for a new conversation
- // and the database will handle any conflicts
-
return slug, nil
}
diff --git a/slug/slug_test.go b/slug/slug_test.go
index 8b6baf814a5368d3b46ecee0262daefec6d67d13..6ef78739742092889950714aa9b9fe3020b87d7a 100644
--- a/slug/slug_test.go
+++ b/slug/slug_test.go
@@ -9,6 +9,7 @@ import (
"shelley.exe.dev/db"
"shelley.exe.dev/llm"
+ "shelley.exe.dev/models"
)
func TestSanitize(t *testing.T) {
@@ -100,6 +101,14 @@ func (m *MockLLMProvider) GetService(modelID string) (llm.Service, error) {
return m.Service, nil
}
+func (m *MockLLMProvider) GetAvailableModels() []string {
+ return []string{"mock"}
+}
+
+func (m *MockLLMProvider) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
+
// TestGenerateSlug_DatabaseIntegration tests slug generation with actual database conflicts
func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
// Create temporary database
@@ -135,7 +144,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
}
// Generate first slug - should succeed with "test-slug"
- slug1, err := GenerateSlug(ctx, mockLLM, database, logger, conv1.ConversationID, "Test message", "")
+ slug1, err := GenerateSlug(ctx, mockLLM, database, logger, conv1.ConversationID, "Test message", "test-model")
if err != nil {
t.Fatalf("Failed to generate first slug: %v", err)
}
@@ -150,7 +159,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
}
// Generate second slug - should get "test-slug-1" due to conflict
- slug2, err := GenerateSlug(ctx, mockLLM, database, logger, conv2.ConversationID, "Test message", "")
+ slug2, err := GenerateSlug(ctx, mockLLM, database, logger, conv2.ConversationID, "Test message", "test-model")
if err != nil {
t.Fatalf("Failed to generate second slug: %v", err)
}
@@ -165,7 +174,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
}
// Generate third slug - should get "test-slug-2" due to conflict
- slug3, err := GenerateSlug(ctx, mockLLM, database, logger, conv3.ConversationID, "Test message", "")
+ slug3, err := GenerateSlug(ctx, mockLLM, database, logger, conv3.ConversationID, "Test message", "test-model")
if err != nil {
t.Fatalf("Failed to generate third slug: %v", err)
}
@@ -203,6 +212,14 @@ func (m *MockLLMProviderWithError) GetService(modelID string) (llm.Service, erro
return nil, fmt.Errorf("model not available")
}
+func (m *MockLLMProviderWithError) GetAvailableModels() []string {
+ return []string{}
+}
+
+func (m *MockLLMProviderWithError) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
+
// MockLLMProviderWithServiceError provides a mock LLM provider that returns a service with error
type MockLLMProviderWithServiceError struct{}
@@ -210,6 +227,14 @@ func (m *MockLLMProviderWithServiceError) GetService(modelID string) (llm.Servic
return &MockLLMServiceWithError{}, nil
}
+func (m *MockLLMProviderWithServiceError) GetAvailableModels() []string {
+ return []string{"mock"}
+}
+
+func (m *MockLLMProviderWithServiceError) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
+
// TestGenerateSlug_LLMError tests error handling when LLM service fails
func TestGenerateSlug_LLMError(t *testing.T) {
mockLLM := &MockLLMProviderWithServiceError{}
@@ -218,8 +243,8 @@ func TestGenerateSlug_LLMError(t *testing.T) {
Level: slog.LevelWarn,
}))
- // Test that LLM error is properly propagated
- _, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "")
+ // Test that LLM error is properly propagated (pass a model ID so we get a service)
+ _, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "test-model")
if err == nil {
t.Error("Expected error from LLM service, got nil")
}
@@ -255,7 +280,7 @@ func TestGenerateSlug_EmptyResponse(t *testing.T) {
Level: slog.LevelWarn,
}))
- _, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "")
+ _, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "test-model")
if err == nil {
t.Error("Expected error for empty LLM response, got nil")
}
@@ -271,6 +296,14 @@ func (m *MockLLMProviderWithEmptyResponse) GetService(modelID string) (llm.Servi
return &MockLLMServiceEmptyResponse{}, nil
}
+func (m *MockLLMProviderWithEmptyResponse) GetAvailableModels() []string {
+ return []string{"mock"}
+}
+
+func (m *MockLLMProviderWithEmptyResponse) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
+
// MockLLMServiceEmptyResponse provides a mock LLM service that returns empty response
type MockLLMServiceEmptyResponse struct{}
@@ -301,7 +334,7 @@ func TestGenerateSlug_SanitizationError(t *testing.T) {
Level: slog.LevelWarn,
}))
- _, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "")
+ _, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "test-model")
if err == nil {
t.Error("Expected error for empty slug after sanitization, got nil")
}
@@ -357,7 +390,7 @@ func TestGenerateSlug_DatabaseError(t *testing.T) {
}
closedDB.Close()
- _, err = GenerateSlug(ctx, mockLLM, closedDB, logger, "test-conversation-id", "Test message", "")
+ _, err = GenerateSlug(ctx, mockLLM, closedDB, logger, "test-conversation-id", "Test message", "test-model")
if err == nil {
t.Error("Expected database error, got nil")
}
@@ -386,9 +419,9 @@ func TestGenerateSlug_PredictableModel(t *testing.T) {
}
}
-// TestGenerateSlug_PredictableModelFallback tests fallback when predictable model is not available
-func TestGenerateSlug_PredictableModelFallback(t *testing.T) {
- // Mock LLM provider that doesn't have predictable model but has other models
+// TestGenerateSlug_ConversationModelFallback tests the fallback to the conversation model when no slug-tagged models exist
+func TestGenerateSlug_ConversationModelFallback(t *testing.T) {
+ // Mock LLM provider that doesn't have the predictable model but does have a conversation model
mockLLM := &MockLLMProviderPredictableFallback{
fallbackService: &MockLLMService{
ResponseText: "fallback-slug",
@@ -399,10 +432,10 @@ func TestGenerateSlug_PredictableModelFallback(t *testing.T) {
Level: slog.LevelDebug,
}))
- // Test that fallback to preferred models works when predictable is not available
- slug, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "predictable")
+ // Test that fallback to the conversation model works when no slug-tagged models exist
+ slug, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "my-custom-model")
if err != nil {
- t.Fatalf("Failed to generate slug with fallback: %v", err)
+ t.Fatalf("Failed to generate slug with conversation model fallback: %v", err)
}
if slug != "fallback-slug" {
t.Errorf("Expected 'fallback-slug', got %q", slug)
@@ -420,3 +453,11 @@ func (m *MockLLMProviderPredictableFallback) GetService(modelID string) (llm.Service, error) {
}
return m.fallbackService, nil
}
+
+func (m *MockLLMProviderPredictableFallback) GetAvailableModels() []string {
+ return []string{"my-custom-model"}
+}
+
+func (m *MockLLMProviderPredictableFallback) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
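
Every mock provider in this test file grows the same two methods, which implies the provider interface consumed by the slug package now requires them. Below is a minimal, self-contained Go sketch of that interface as inferred from the mocks; the interface name, the stand-in ModelInfo struct, and its fields are assumptions, not taken from the llm or models packages.

```go
package llm

// ModelInfo stands in for shelley.exe.dev/models.ModelInfo; its real
// fields are not shown in this diff, so these are placeholders.
type ModelInfo struct {
	ID          string
	DisplayName string
}

// Service is the per-model LLM service; its methods are elided here.
type Service interface{}

// Provider is a guess at the interface the test mocks satisfy: every
// mock in this diff implements exactly these three methods.
type Provider interface {
	GetService(modelID string) (Service, error)
	GetAvailableModels() []string
	GetModelInfo(modelID string) *ModelInfo
}
```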
diff --git a/test/server_test.go b/test/server_test.go
index da9de5678f4371e2f2fceb7a8e1bf057944f6d81..2fa0f7101d65459bc6822add6b2c22fb0f102099 100644
--- a/test/server_test.go
+++ b/test/server_test.go
@@ -23,6 +23,7 @@ import (
"shelley.exe.dev/db/generated"
"shelley.exe.dev/llm"
"shelley.exe.dev/loop"
+ "shelley.exe.dev/models"
"shelley.exe.dev/server"
"shelley.exe.dev/slug"
)
@@ -907,6 +908,10 @@ func (m *inspectableLLMManager) HasModel(modelID string) bool {
return modelID == "predictable"
}
+func (m *inspectableLLMManager) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
+
func TestVersionEndpoint(t *testing.T) {
// Create temp DB-backed server
ctx := context.Background()
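
The nil-returning GetModelInfo stub is enough for tests that never consult model metadata. When a consumed interface grows like this, a compile-time assertion catches mocks that fall behind at build time rather than mid-test; a generic sketch of the idiom, with types invented for illustration:

```go
package main

// Compile-time assertion idiom: if Mock ever stops satisfying Iface,
// the build fails on the var line instead of deep inside a test run.
type Iface interface{ Do() }

type Mock struct{}

func (Mock) Do() {}

var _ Iface = Mock{}

func main() {}
```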
diff --git a/ui/src/App.tsx b/ui/src/App.tsx
index 5e7d38c697ab13e7ce5ab6cf913e6ff5db943110..955847977129e4273b3c4dbaf8698cdb82b01fe2 100644
--- a/ui/src/App.tsx
+++ b/ui/src/App.tsx
@@ -2,6 +2,7 @@ import React, { useState, useEffect, useCallback, useRef } from "react";
import ChatInterface from "./components/ChatInterface";
import ConversationDrawer from "./components/ConversationDrawer";
import CommandPalette from "./components/CommandPalette";
+import ModelsModal from "./components/ModelsModal";
import { Conversation, ConversationWithState, ConversationListUpdate } from "./types";
import { api } from "./services/api";
@@ -65,6 +66,8 @@ function App() {
const [drawerCollapsed, setDrawerCollapsed] = useState(false);
const [commandPaletteOpen, setCommandPaletteOpen] = useState(false);
const [diffViewerTrigger, setDiffViewerTrigger] = useState(0);
+ const [modelsModalOpen, setModelsModalOpen] = useState(false);
+ const [modelsRefreshTrigger, setModelsRefreshTrigger] = useState(0);
const [loading, setLoading] = useState(true);
const [error, setError] = useState(null);
const [subagentUpdate, setSubagentUpdate] = useState(null);
@@ -348,6 +351,7 @@ function App() {
isDrawerCollapsed={drawerCollapsed}
onToggleDrawerCollapse={toggleDrawerCollapsed}
openDiffViewerTrigger={diffViewerTrigger}
+ modelsRefreshTrigger={modelsRefreshTrigger}
/>
@@ -368,9 +372,19 @@ function App() {
setDiffViewerTrigger((prev) => prev + 1);
setCommandPaletteOpen(false);
}}
+ onOpenModelsModal={() => {
+ setModelsModalOpen(true);
+ setCommandPaletteOpen(false);
+ }}
hasCwd={!!(currentConversation?.cwd || mostRecentCwd)}
/>
+ <ModelsModal
+ isOpen={modelsModalOpen}
+ onClose={() => setModelsModalOpen(false)}
+ onModelsChanged={() => setModelsRefreshTrigger((prev) => prev + 1)}
+ />
+
{/* Backdrop for mobile drawer */}
{drawerOpen && (
<div className="drawer-backdrop" onClick={() => setDrawerOpen(false)} />
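
The modelsRefreshTrigger plumbing above is the counter-as-event React pattern: the parent bumps an integer, and any child that lists it in an effect's dependency array re-runs that effect. A self-contained sketch, with component and endpoint names invented for the example:

```tsx
import React, { useEffect, useState } from "react";

// Child refetches whenever the counter changes; trigger 0 marks the
// initial mount and is skipped (a simplified version of the guard
// used in ChatInterface).
function Child({ refreshTrigger }: { refreshTrigger: number }) {
  const [items, setItems] = useState<string[]>([]);
  useEffect(() => {
    if (refreshTrigger === 0) return;
    fetch("/api/items") // hypothetical endpoint
      .then((r) => r.json())
      .then(setItems)
      .catch((err) => console.error("refresh failed:", err));
  }, [refreshTrigger]);
  return <ul>{items.map((item) => <li key={item}>{item}</li>)}</ul>;
}

// Parent owns the counter and bumps it to broadcast "data changed".
function Parent() {
  const [refreshTrigger, setRefreshTrigger] = useState(0);
  return (
    <>
      <button onClick={() => setRefreshTrigger((prev) => prev + 1)}>Refresh</button>
      <Child refreshTrigger={refreshTrigger} />
    </>
  );
}
```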
diff --git a/ui/src/components/ChatInterface.tsx b/ui/src/components/ChatInterface.tsx
index a2251a115671add51ca3515784d68391a741b154..8c26f0a4d6e452007462ae7ee1f73a5aba35239d 100644
--- a/ui/src/components/ChatInterface.tsx
+++ b/ui/src/components/ChatInterface.tsx
@@ -394,6 +394,7 @@ interface ChatInterfaceProps {
isDrawerCollapsed?: boolean;
onToggleDrawerCollapse?: () => void;
openDiffViewerTrigger?: number; // increment to trigger opening diff viewer
+ modelsRefreshTrigger?: number; // increment to trigger models list refresh
}
function ChatInterface({
@@ -409,12 +410,15 @@ function ChatInterface({
isDrawerCollapsed,
onToggleDrawerCollapse,
openDiffViewerTrigger,
+ modelsRefreshTrigger,
}: ChatInterfaceProps) {
const [messages, setMessages] = useState([]);
const [loading, setLoading] = useState(true);
const [sending, setSending] = useState(false);
const [error, setError] = useState(null);
- const models = window.__SHELLEY_INIT__?.models || [];
+ const [models, setModels] = useState<
+ Array<{ id: string; display_name?: string; ready: boolean; max_context_tokens?: number }>
+ >(window.__SHELLEY_INIT__?.models || []);
const [selectedModel, setSelectedModelState] = useState(() => {
// First check localStorage for a sticky model preference
const storedModel = localStorage.getItem("shelley_selected_model");
@@ -473,6 +477,26 @@ function ChatInterface({
setCwdInitialized(true);
}
}, [mostRecentCwd, cwdInitialized]);
+
+ // Refresh the models list when triggered (e.g., after custom model changes) or when starting a new conversation
+ useEffect(() => {
+ // Skip on initial mount with trigger=0, but always refresh when starting a new conversation
+ if (modelsRefreshTrigger === undefined) return;
+ if (modelsRefreshTrigger === 0 && conversationId !== null) return;
+ api
+ .getModels()
+ .then((newModels) => {
+ setModels(newModels);
+ // Also update the global init data so other components see the change
+ if (window.__SHELLEY_INIT__) {
+ window.__SHELLEY_INIT__.models = newModels;
+ }
+ })
+ .catch((err) => {
+ console.error("Failed to refresh models:", err);
+ });
+ }, [modelsRefreshTrigger, conversationId]);
+
const [cwdError, setCwdError] = useState(null);
const [editingModel, setEditingModel] = useState(false);
const [showDirectoryPicker, setShowDirectoryPicker] = useState(false);
@@ -1459,7 +1483,7 @@ function ChatInterface({
>
{models.map((model) => (
- <option key={model.id} value={model.id}>{model.id}</option>
+ <option key={model.id} value={model.id}>{model.display_name || model.id}</option>
))}
@@ -1469,7 +1493,7 @@ function ChatInterface({
onClick={() => setEditingModel(true)}
disabled={sending}
>
- {selectedModel}
+ {models.find((m) => m.id === selectedModel)?.display_name || selectedModel}
)}
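
The useState annotation and the display_name fallback above pin down the UI-side model record. For reference, that shape and the label logic as a standalone helper; the interface name and function are illustrative, not part of the codebase:

```ts
// Fields are copied from the useState type annotation in the hunk above;
// the interface name and helper function are invented for illustration.
interface UIModel {
  id: string;
  display_name?: string;
  ready: boolean;
  max_context_tokens?: number;
}

// Prefer the human-readable name and fall back to the raw model ID,
// matching the button label logic in ChatInterface.
function modelLabel(models: UIModel[], selectedModel: string): string {
  return models.find((m) => m.id === selectedModel)?.display_name ?? selectedModel;
}
```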
diff --git a/ui/src/components/CommandPalette.tsx b/ui/src/components/CommandPalette.tsx
index bf4a8953401531ba2c8a0f171117171c060ae8e1..b97310b810cf28e6841df585275cd2338dc9edc5 100644
--- a/ui/src/components/CommandPalette.tsx
+++ b/ui/src/components/CommandPalette.tsx
@@ -19,6 +19,7 @@ interface CommandPaletteProps {
onNewConversation: () => void;
onSelectConversation: (id: string) => void;
onOpenDiffViewer: () => void;
+ onOpenModelsModal: () => void;
hasCwd: boolean;
}
@@ -64,6 +65,7 @@ function CommandPalette({
onNewConversation,
onSelectConversation,
onOpenDiffViewer,
+ onOpenModelsModal,
hasCwd,
}: CommandPaletteProps) {
const [query, setQuery] = useState("");
@@ -161,8 +163,46 @@ function CommandPalette({
});
}
+ items.push({
+ id: "manage-models",
+ type: "action",
+ title: "Add/Remove Models/Keys",
+ subtitle: "Configure custom AI models and API keys",
+ icon: (
+
+ ),
+ action: () => {
+ onOpenModelsModal();
+ onClose();
+ },
+ keywords: [
+ "model",
+ "key",
+ "api",
+ "configure",
+ "settings",
+ "anthropic",
+ "openai",
+ "gemini",
+ "custom",
+ ],
+ });
+
return items;
- }, [onNewConversation, onOpenDiffViewer, onClose, hasCwd]);
+ }, [onNewConversation, onOpenDiffViewer, onOpenModelsModal, onClose, hasCwd]);
// Convert conversations to command items
const conversationToItem = useCallback(
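
For reference, the entry shape implied by the items.push(...) call above; field names are read off that call, while the type name and which fields are optional are assumptions:

```ts
import type { ReactNode } from "react";

// Field names mirror the manage-models entry pushed above; the type name
// and the optional markers are guesses, not taken from CommandPalette.tsx.
interface CommandItem {
  id: string;
  type: "action" | "conversation";
  title: string;
  subtitle?: string;
  icon?: ReactNode;
  keywords?: string[];
  action: () => void;
}
```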
diff --git a/ui/src/components/Modal.tsx b/ui/src/components/Modal.tsx
index 9b48166a710ef3a5a524ada633a23ac1288bf145..0f2d11816fedafad9910591a33651c78d4ade1fc 100644
--- a/ui/src/components/Modal.tsx
+++ b/ui/src/components/Modal.tsx
@@ -4,10 +4,12 @@ interface ModalProps {
isOpen: boolean;
onClose: () => void;
title: string;
+ titleRight?: React.ReactNode;
children: React.ReactNode;
+ className?: string;
}
-function Modal({ isOpen, onClose, title, children }: ModalProps) {
+function Modal({ isOpen, onClose, title, titleRight, children, className }: ModalProps) {
if (!isOpen) return null;
const handleBackdropClick = (e: React.MouseEvent) => {
@@ -18,10 +20,11 @@ function Modal({ isOpen, onClose, title, children }: ModalProps) {
return (
- <div className="modal-content">
+ <div className={`modal-content ${className || ""}`}>
{/* Header */}
<h2>{title}</h2>
+ {titleRight && <div>{titleRight}</div>}
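
A hypothetical usage of the widened Modal props. ModelsModal's real markup is not part of this diff, so the wrapper component, handlers, and class name below are invented purely to show titleRight and className in use:

```tsx
import React, { useState } from "react";
import Modal from "./Modal";

// Invented example component: demonstrates the new titleRight slot and
// the pass-through className added to Modal in this diff.
function ModelsModalExample() {
  const [open, setOpen] = useState(true);
  return (
    <Modal
      isOpen={open}
      onClose={() => setOpen(false)}
      title="Models"
      titleRight={<button onClick={() => console.log("refresh")}>Refresh</button>}
      className="models-modal"
    >
      <p>Model list goes here.</p>
    </Modal>
  );
}
```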