Detailed changes
@@ -27,3 +27,5 @@
./bin/shelley -config /exe.dev/shelley.json -db /tmp/shelley-test.db serve -port 8002
```
Then use browser tools to navigate to http://localhost:8002/ and interact with the UI.
+13. NEVER use alert(), confirm(), or prompt(). Use proper UI components like tooltips, modals, or toasts instead.
+14. SQL migrations and frontend changes require rebuilding the binary (`make build` or `go generate ./... && cd ui && pnpm run build`).
@@ -101,7 +101,7 @@ func runServe(global GlobalConfig, args []string) {
// Build LLM configuration
llmConfig := buildLLMConfig(logger, global.ConfigPath, global.TerminalURL, global.DefaultModel, database)
- // Initialize LLM service manager
+ // Initialize LLM service manager (includes custom model support via database)
llmManager := server.NewLLMServiceManager(llmConfig)
// Log available models
@@ -115,6 +115,20 @@ func (db *DB) Migrate(ctx context.Context) error {
// Sort migrations by number
sort.Strings(migrations)
+ // Check for duplicate migration numbers
+ seenNumbers := make(map[string]string) // number -> filename
+ for _, migration := range migrations {
+ matches := migrationPattern.FindStringSubmatch(migration)
+ if len(matches) < 2 {
+ continue
+ }
+ num := matches[1]
+ if existing, ok := seenNumbers[num]; ok {
+ return fmt.Errorf("duplicate migration number %s: %s and %s", num, existing, migration)
+ }
+ seenNumbers[num] = migration
+ }
+
// Get executed migrations
executedMigrations := make(map[int]bool)
var tableName string
@@ -850,3 +864,68 @@ func reconstructRequestBody(ctx context.Context, q *generated.Queries, requestID
*result = parentBody[:prefixLen] + suffix
return nil
}
+
+// GetModels returns all models from the database
+func (db *DB) GetModels(ctx context.Context) ([]generated.Model, error) {
+ var models []generated.Model
+ err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
+ q := generated.New(rx.Conn())
+ var err error
+ models, err = q.GetModels(ctx)
+ return err
+ })
+ return models, err
+}
+
+// GetModel returns a model by ID
+func (db *DB) GetModel(ctx context.Context, modelID string) (*generated.Model, error) {
+ var model generated.Model
+ err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
+ q := generated.New(rx.Conn())
+ var err error
+ model, err = q.GetModel(ctx, modelID)
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return &model, nil
+}
+
+// CreateModel creates a new model
+func (db *DB) CreateModel(ctx context.Context, params generated.CreateModelParams) (*generated.Model, error) {
+ var model generated.Model
+ err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
+ q := generated.New(tx.Conn())
+ var err error
+ model, err = q.CreateModel(ctx, params)
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return &model, nil
+}
+
+// UpdateModel updates a model
+func (db *DB) UpdateModel(ctx context.Context, params generated.UpdateModelParams) (*generated.Model, error) {
+ var model generated.Model
+ err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
+ q := generated.New(tx.Conn())
+ var err error
+ model, err = q.UpdateModel(ctx, params)
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return &model, nil
+}
+
+// DeleteModel deletes a model
+func (db *DB) DeleteModel(ctx context.Context, modelID string) error {
+ return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
+ q := generated.New(tx.Conn())
+ return q.DeleteModel(ctx, modelID)
+ })
+}
@@ -150,22 +150,24 @@ func (q *Queries) InsertLLMRequest(ctx context.Context, arg InsertLLMRequestPara
}
const listRecentLLMRequests = `-- name: ListRecentLLMRequests :many
-SELECT
- id,
- conversation_id,
- model,
- provider,
- url,
- LENGTH(request_body) as request_body_length,
- LENGTH(response_body) as response_body_length,
- status_code,
- error,
- duration_ms,
- created_at,
- prefix_request_id,
- prefix_length
-FROM llm_requests
-ORDER BY id DESC
+SELECT
+ r.id,
+ r.conversation_id,
+ r.model,
+ m.display_name as model_display_name,
+ r.provider,
+ r.url,
+ LENGTH(r.request_body) as request_body_length,
+ LENGTH(r.response_body) as response_body_length,
+ r.status_code,
+ r.error,
+ r.duration_ms,
+ r.created_at,
+ r.prefix_request_id,
+ r.prefix_length
+FROM llm_requests r
+LEFT JOIN models m ON r.model = m.model_id
+ORDER BY r.id DESC
LIMIT ?
`
@@ -173,6 +175,7 @@ type ListRecentLLMRequestsRow struct {
ID int64 `json:"id"`
ConversationID *string `json:"conversation_id"`
Model string `json:"model"`
+ ModelDisplayName *string `json:"model_display_name"`
Provider string `json:"provider"`
Url string `json:"url"`
RequestBodyLength *int64 `json:"request_body_length"`
@@ -198,6 +201,7 @@ func (q *Queries) ListRecentLLMRequests(ctx context.Context, limit int64) ([]Lis
&i.ID,
&i.ConversationID,
&i.Model,
+ &i.ModelDisplayName,
&i.Provider,
&i.Url,
&i.RequestBodyLength,
@@ -52,3 +52,16 @@ type Migration struct {
MigrationName string `json:"migration_name"`
ExecutedAt *time.Time `json:"executed_at"`
}
+
+type Model struct {
+ ModelID string `json:"model_id"`
+ DisplayName string `json:"display_name"`
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ ApiKey string `json:"api_key"`
+ ModelName string `json:"model_name"`
+ MaxTokens int64 `json:"max_tokens"`
+ Tags string `json:"tags"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
@@ -0,0 +1,175 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+// sqlc v1.30.0
+// source: models.sql
+
+package generated
+
+import (
+ "context"
+)
+
+const createModel = `-- name: CreateModel :one
+INSERT INTO models (model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags)
+VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+RETURNING model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags, created_at, updated_at
+`
+
+type CreateModelParams struct {
+ ModelID string `json:"model_id"`
+ DisplayName string `json:"display_name"`
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ ApiKey string `json:"api_key"`
+ ModelName string `json:"model_name"`
+ MaxTokens int64 `json:"max_tokens"`
+ Tags string `json:"tags"`
+}
+
+func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model, error) {
+ row := q.db.QueryRowContext(ctx, createModel,
+ arg.ModelID,
+ arg.DisplayName,
+ arg.ProviderType,
+ arg.Endpoint,
+ arg.ApiKey,
+ arg.ModelName,
+ arg.MaxTokens,
+ arg.Tags,
+ )
+ var i Model
+ err := row.Scan(
+ &i.ModelID,
+ &i.DisplayName,
+ &i.ProviderType,
+ &i.Endpoint,
+ &i.ApiKey,
+ &i.ModelName,
+ &i.MaxTokens,
+ &i.Tags,
+ &i.CreatedAt,
+ &i.UpdatedAt,
+ )
+ return i, err
+}
+
+const deleteModel = `-- name: DeleteModel :exec
+DELETE FROM models WHERE model_id = ?
+`
+
+func (q *Queries) DeleteModel(ctx context.Context, modelID string) error {
+ _, err := q.db.ExecContext(ctx, deleteModel, modelID)
+ return err
+}
+
+const getModel = `-- name: GetModel :one
+SELECT model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags, created_at, updated_at FROM models WHERE model_id = ?
+`
+
+func (q *Queries) GetModel(ctx context.Context, modelID string) (Model, error) {
+ row := q.db.QueryRowContext(ctx, getModel, modelID)
+ var i Model
+ err := row.Scan(
+ &i.ModelID,
+ &i.DisplayName,
+ &i.ProviderType,
+ &i.Endpoint,
+ &i.ApiKey,
+ &i.ModelName,
+ &i.MaxTokens,
+ &i.Tags,
+ &i.CreatedAt,
+ &i.UpdatedAt,
+ )
+ return i, err
+}
+
+const getModels = `-- name: GetModels :many
+SELECT model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags, created_at, updated_at FROM models ORDER BY created_at ASC
+`
+
+func (q *Queries) GetModels(ctx context.Context) ([]Model, error) {
+ rows, err := q.db.QueryContext(ctx, getModels)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ items := []Model{}
+ for rows.Next() {
+ var i Model
+ if err := rows.Scan(
+ &i.ModelID,
+ &i.DisplayName,
+ &i.ProviderType,
+ &i.Endpoint,
+ &i.ApiKey,
+ &i.ModelName,
+ &i.MaxTokens,
+ &i.Tags,
+ &i.CreatedAt,
+ &i.UpdatedAt,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const updateModel = `-- name: UpdateModel :one
+UPDATE models
+SET display_name = ?,
+ provider_type = ?,
+ endpoint = ?,
+ api_key = ?,
+ model_name = ?,
+ max_tokens = ?,
+ tags = ?,
+ updated_at = CURRENT_TIMESTAMP
+WHERE model_id = ?
+RETURNING model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags, created_at, updated_at
+`
+
+type UpdateModelParams struct {
+ DisplayName string `json:"display_name"`
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ ApiKey string `json:"api_key"`
+ ModelName string `json:"model_name"`
+ MaxTokens int64 `json:"max_tokens"`
+ Tags string `json:"tags"`
+ ModelID string `json:"model_id"`
+}
+
+func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model, error) {
+ row := q.db.QueryRowContext(ctx, updateModel,
+ arg.DisplayName,
+ arg.ProviderType,
+ arg.Endpoint,
+ arg.ApiKey,
+ arg.ModelName,
+ arg.MaxTokens,
+ arg.Tags,
+ arg.ModelID,
+ )
+ var i Model
+ err := row.Scan(
+ &i.ModelID,
+ &i.DisplayName,
+ &i.ProviderType,
+ &i.Endpoint,
+ &i.ApiKey,
+ &i.ModelName,
+ &i.MaxTokens,
+ &i.Tags,
+ &i.CreatedAt,
+ &i.UpdatedAt,
+ )
+ return i, err
+}
@@ -24,22 +24,24 @@ LIMIT 1;
SELECT * FROM llm_requests WHERE id = ?;
-- name: ListRecentLLMRequests :many
-SELECT
- id,
- conversation_id,
- model,
- provider,
- url,
- LENGTH(request_body) as request_body_length,
- LENGTH(response_body) as response_body_length,
- status_code,
- error,
- duration_ms,
- created_at,
- prefix_request_id,
- prefix_length
-FROM llm_requests
-ORDER BY id DESC
+SELECT
+ r.id,
+ r.conversation_id,
+ r.model,
+ m.display_name as model_display_name,
+ r.provider,
+ r.url,
+ LENGTH(r.request_body) as request_body_length,
+ LENGTH(r.response_body) as response_body_length,
+ r.status_code,
+ r.error,
+ r.duration_ms,
+ r.created_at,
+ r.prefix_request_id,
+ r.prefix_length
+FROM llm_requests r
+LEFT JOIN models m ON r.model = m.model_id
+ORDER BY r.id DESC
LIMIT ?;
-- name: GetLLMRequestBody :one
@@ -0,0 +1,26 @@
+-- name: GetModels :many
+SELECT * FROM models ORDER BY created_at ASC;
+
+-- name: GetModel :one
+SELECT * FROM models WHERE model_id = ?;
+
+-- name: CreateModel :one
+INSERT INTO models (model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags)
+VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+RETURNING *;
+
+-- name: UpdateModel :one
+UPDATE models
+SET display_name = ?,
+ provider_type = ?,
+ endpoint = ?,
+ api_key = ?,
+ model_name = ?,
+ max_tokens = ?,
+ tags = ?,
+ updated_at = CURRENT_TIMESTAMP
+WHERE model_id = ?
+RETURNING *;
+
+-- name: DeleteModel :exec
+DELETE FROM models WHERE model_id = ?;
@@ -0,0 +1,15 @@
+-- Models table
+-- Stores user-configured LLM models with API keys
+
+CREATE TABLE models (
+ model_id TEXT PRIMARY KEY,
+ display_name TEXT NOT NULL,
+ provider_type TEXT NOT NULL CHECK (provider_type IN ('anthropic', 'openai', 'openai-responses', 'gemini')),
+ endpoint TEXT NOT NULL,
+ api_key TEXT NOT NULL,
+ model_name TEXT NOT NULL, -- The actual model name sent to the API (e.g., "claude-sonnet-4-5-20250514")
+ max_tokens INTEGER NOT NULL DEFAULT 200000,
+ tags TEXT NOT NULL DEFAULT '', -- Comma-separated tags (e.g., "slug" for slug generation)
+ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
@@ -292,6 +292,13 @@ var (
APIKeyEnv: OpenAIAPIKeyEnv,
}
+ GPT52Codex = Model{
+ UserName: "gpt-5.2-codex",
+ ModelName: "gpt-5.2-codex",
+ URL: OpenAIURL,
+ APIKeyEnv: OpenAIAPIKeyEnv,
+ }
+
// Skaband-specific model names.
// Provider details (URL and APIKeyEnv) are handled by skaband
Qwen = Model{
@@ -329,6 +336,7 @@ var ModelsRegistry = []Model{
GPT5Mini,
GPT5Nano,
GPT5Codex,
+ GPT52Codex,
O3,
O4Mini,
Gemini25Flash,
@@ -525,22 +533,6 @@ func fromLLMMessage(msg llm.Message) []openai.ChatCompletionMessage {
return messages
}
-// requiresMaxCompletionTokens returns true if the model requires max_completion_tokens instead of max_tokens.
-func (m Model) requiresMaxCompletionTokens() bool {
- // Reasoning models always use max_completion_tokens
- if m.IsReasoningModel {
- return true
- }
-
- // GPT-5 series models also require max_completion_tokens
- switch m.ModelName {
- case "gpt-5.1", "gpt-5.1-mini", "gpt-5.1-nano":
- return true
- default:
- return false
- }
-}
-
// fromLLMToolChoice converts llm.ToolChoice to the format expected by OpenAI.
func fromLLMToolChoice(tc *llm.ToolChoice) any {
if tc == nil {
@@ -812,15 +804,11 @@ func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error
// Create the OpenAI request
req := openai.ChatCompletionRequest{
- Model: model.ModelName,
- Messages: allMessages,
- Tools: tools,
- ToolChoice: fromLLMToolChoice(ir.ToolChoice), // TODO: make fromLLMToolChoice return an error when a perfect translation is not possible
- }
- if model.requiresMaxCompletionTokens() {
- req.MaxCompletionTokens = cmp.Or(s.MaxTokens, DefaultMaxTokens)
- } else {
- req.MaxTokens = cmp.Or(s.MaxTokens, DefaultMaxTokens)
+ Model: model.ModelName,
+ Messages: allMessages,
+ Tools: tools,
+ ToolChoice: fromLLMToolChoice(ir.ToolChoice), // TODO: make fromLLMToolChoice return an error when a perfect translation is not possible
+ MaxCompletionTokens: cmp.Or(s.MaxTokens, DefaultMaxTokens),
}
// Construct the full URL for logging and debugging
fullURL := baseURL + "/chat/completions"
@@ -340,6 +340,8 @@ func (s *ResponsesService) TokenContextWindow() int {
// Use the same context window logic as the regular service
switch model.ModelName {
+ case "gpt-5.2-codex":
+ return 272000 // 272k for gpt-5.2-codex
case "gpt-5.1-codex":
return 256000 // 256k for gpt-5.1-codex
case "gpt-4.1-2025-04-14", "gpt-4.1-mini-2025-04-14", "gpt-4.1-nano-2025-04-14":
@@ -284,6 +284,7 @@ func TestResponsesServiceTokenContextWindow(t *testing.T) {
model Model
expected int
}{
+ {model: GPT52Codex, expected: 272000},
{model: GPT5Codex, expected: 256000},
{model: GPT41, expected: 200000},
{model: GPT4o, expected: 128000},
@@ -12,106 +12,6 @@ import (
"shelley.exe.dev/llm"
)
-func TestRequiresMaxCompletionTokens(t *testing.T) {
- tests := []struct {
- name string
- model Model
- expected bool
- }{
- {
- name: "GPT-5 requires max_completion_tokens",
- model: GPT5,
- expected: true,
- },
- {
- name: "GPT-5 Mini requires max_completion_tokens",
- model: GPT5Mini,
- expected: true,
- },
- {
- name: "O3 reasoning model requires max_completion_tokens",
- model: O3,
- expected: true,
- },
- {
- name: "O4-mini reasoning model requires max_completion_tokens",
- model: O4Mini,
- expected: true,
- },
- {
- name: "GPT-4.1 uses max_tokens",
- model: GPT41,
- expected: false,
- },
- {
- name: "GPT-4o uses max_tokens",
- model: GPT4o,
- expected: false,
- },
- {
- name: "GPT-4o Mini uses max_tokens",
- model: GPT4oMini,
- expected: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := tt.model.requiresMaxCompletionTokens()
- if result != tt.expected {
- t.Errorf("requiresMaxCompletionTokens() = %v, expected %v", result, tt.expected)
- }
- })
- }
-}
-
-func TestRequestParameterGeneration(t *testing.T) {
- // Test that we can generate the correct request structure without making API calls
- tests := []struct {
- name string
- model Model
- expectMaxTokens bool
- expectMaxCompletionTokens bool
- }{
- {
- name: "GPT-5 uses max_completion_tokens",
- model: GPT5,
- expectMaxTokens: false,
- expectMaxCompletionTokens: true,
- },
- {
- name: "GPT-5 Mini uses max_completion_tokens",
- model: GPT5Mini,
- expectMaxTokens: false,
- expectMaxCompletionTokens: true,
- },
- {
- name: "GPT-4.1 uses max_tokens",
- model: GPT41,
- expectMaxTokens: true,
- expectMaxCompletionTokens: false,
- },
- {
- name: "O3 uses max_completion_tokens",
- model: O3,
- expectMaxTokens: false,
- expectMaxCompletionTokens: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- usesMaxCompletionTokens := tt.model.requiresMaxCompletionTokens()
- if tt.expectMaxCompletionTokens && !usesMaxCompletionTokens {
- t.Errorf("Expected model %s to use max_completion_tokens, but it doesn't", tt.model.UserName)
- }
- if tt.expectMaxTokens && usesMaxCompletionTokens {
- t.Errorf("Expected model %s to use max_tokens, but it uses max_completion_tokens", tt.model.UserName)
- }
- })
- }
-}
-
func TestToRoleFromString(t *testing.T) {
tests := []struct {
name string
@@ -148,47 +148,15 @@ func All() []Model {
},
},
{
- ID: "gpt-5",
+ ID: "gpt-5.2-codex",
Provider: ProviderOpenAI,
- Description: "GPT-5",
+ Description: "GPT-5.2 Codex",
RequiredEnvVars: []string{"OPENAI_API_KEY"},
Factory: func(config *Config, httpc *http.Client) (llm.Service, error) {
if config.OpenAIAPIKey == "" {
- return nil, fmt.Errorf("gpt-5 requires OPENAI_API_KEY")
+ return nil, fmt.Errorf("gpt-5.2-codex requires OPENAI_API_KEY")
}
- svc := &oai.Service{Model: oai.GPT5, APIKey: config.OpenAIAPIKey, HTTPC: httpc}
- if url := config.getOpenAIURL(); url != "" {
- svc.ModelURL = url
- }
- return svc, nil
- },
- },
- {
- ID: "gpt-5-nano",
- Provider: ProviderOpenAI,
- Description: "GPT-5 Nano",
- RequiredEnvVars: []string{"OPENAI_API_KEY"},
- Factory: func(config *Config, httpc *http.Client) (llm.Service, error) {
- if config.OpenAIAPIKey == "" {
- return nil, fmt.Errorf("gpt-5-nano requires OPENAI_API_KEY")
- }
- svc := &oai.Service{Model: oai.GPT5Nano, APIKey: config.OpenAIAPIKey, HTTPC: httpc}
- if url := config.getOpenAIURL(); url != "" {
- svc.ModelURL = url
- }
- return svc, nil
- },
- },
- {
- ID: "gpt-5.1-codex",
- Provider: ProviderOpenAI,
- Description: "GPT-5.1 Codex (uses Responses API)",
- RequiredEnvVars: []string{"OPENAI_API_KEY"},
- Factory: func(config *Config, httpc *http.Client) (llm.Service, error) {
- if config.OpenAIAPIKey == "" {
- return nil, fmt.Errorf("gpt-5.1-codex requires OPENAI_API_KEY")
- }
- svc := &oai.ResponsesService{Model: oai.GPT5Codex, APIKey: config.OpenAIAPIKey, HTTPC: httpc}
+ svc := &oai.ResponsesService{Model: oai.GPT52Codex, APIKey: config.OpenAIAPIKey, HTTPC: httpc}
if url := config.getOpenAIURL(); url != "" {
svc.ModelURL = url
}
@@ -259,38 +227,6 @@ func All() []Model {
return svc, nil
},
},
- {
- ID: "gemini-2.5-pro",
- Provider: ProviderGemini,
- Description: "Gemini 2.5 Pro",
- RequiredEnvVars: []string{"GEMINI_API_KEY"},
- Factory: func(config *Config, httpc *http.Client) (llm.Service, error) {
- if config.GeminiAPIKey == "" {
- return nil, fmt.Errorf("gemini-2.5-pro requires GEMINI_API_KEY")
- }
- svc := &gem.Service{APIKey: config.GeminiAPIKey, Model: "gemini-2.5-pro", HTTPC: httpc}
- if url := config.getGeminiURL(); url != "" {
- svc.URL = url
- }
- return svc, nil
- },
- },
- {
- ID: "gemini-2.5-flash",
- Provider: ProviderGemini,
- Description: "Gemini 2.5 Flash",
- RequiredEnvVars: []string{"GEMINI_API_KEY"},
- Factory: func(config *Config, httpc *http.Client) (llm.Service, error) {
- if config.GeminiAPIKey == "" {
- return nil, fmt.Errorf("gemini-2.5-flash requires GEMINI_API_KEY")
- }
- svc := &gem.Service{APIKey: config.GeminiAPIKey, Model: "gemini-2.5-flash", HTTPC: httpc}
- if url := config.getGeminiURL(); url != "" {
- svc.URL = url
- }
- return svc, nil
- },
- },
{
ID: "predictable",
Provider: ProviderBuiltIn,
@@ -332,7 +268,8 @@ func Default() Model {
type Manager struct {
services map[string]serviceEntry
logger *slog.Logger
- db *db.DB
+ db *db.DB // for custom models and LLM request recording
+ httpc *http.Client // HTTP client with recording middleware
}
type serviceEntry struct {
@@ -502,6 +439,9 @@ func NewManager(cfg *Config) (*Manager, error) {
httpc = llmhttp.NewClient(nil, nil)
}
+ // Store the HTTP client for use with custom models
+ manager.httpc = httpc
+
for _, model := range All() {
svc, err := model.Factory(cfg, httpc)
if err != nil {
@@ -520,6 +460,34 @@ func NewManager(cfg *Config) (*Manager, error) {
// GetService returns the LLM service for the given model ID, wrapped with logging
func (m *Manager) GetService(modelID string) (llm.Service, error) {
+ // Check custom models first if we have a database
+ if m.db != nil {
+ dbModels, err := m.db.GetModels(context.Background())
+ if err == nil && len(dbModels) > 0 {
+ // Custom models exist - only serve custom models, not built-in ones
+ for _, model := range dbModels {
+ if model.ModelID == modelID {
+ svc := m.createServiceFromModel(&model)
+ if svc != nil {
+ if m.logger != nil {
+ return &loggingService{
+ service: svc,
+ logger: m.logger,
+ modelID: modelID,
+ provider: Provider(model.ProviderType),
+ db: m.db,
+ }, nil
+ }
+ return svc, nil
+ }
+ }
+ }
+ // Custom models exist but this model ID wasn't found among them
+ return nil, fmt.Errorf("unsupported model: %s", modelID)
+ }
+ }
+
+ // No custom models - fall back to built-in models
if entry, ok := m.services[modelID]; ok {
// Wrap with logging if we have a logger
if m.logger != nil {
@@ -538,9 +506,20 @@ func (m *Manager) GetService(modelID string) (llm.Service, error) {
// GetAvailableModels returns a list of available model IDs in the same order as All()
func (m *Manager) GetAvailableModels() []string {
- // Return IDs in the same order as All() for consistency
- all := All()
var ids []string
+
+ // If we have custom models in the database, use ONLY those
+ if m.db != nil {
+ if dbModels, err := m.db.GetModels(context.Background()); err == nil && len(dbModels) > 0 {
+ for _, model := range dbModels {
+ ids = append(ids, model.ModelID)
+ }
+ return ids
+ }
+ }
+
+ // No custom models - fall back to built-in models in the same order as All()
+ all := All()
for _, model := range all {
if _, ok := m.services[model.ID]; ok {
ids = append(ids, model.ID)
@@ -551,6 +530,80 @@ func (m *Manager) GetAvailableModels() []string {
// HasModel reports whether the manager has a service for the given model ID
func (m *Manager) HasModel(modelID string) bool {
+ // Check custom models first
+ if m.db != nil {
+ if model, err := m.db.GetModel(context.Background(), modelID); err == nil && model != nil {
+ return true
+ }
+ }
_, ok := m.services[modelID]
return ok
}
+
+// ModelInfo contains display name and tags for a model
+type ModelInfo struct {
+ DisplayName string
+ Tags string
+}
+
+// GetModelInfo returns the display name and tags for a model
+func (m *Manager) GetModelInfo(modelID string) *ModelInfo {
+ if m.db == nil {
+ return nil
+ }
+ model, err := m.db.GetModel(context.Background(), modelID)
+ if err != nil {
+ return nil
+ }
+ return &ModelInfo{
+ DisplayName: model.DisplayName,
+ Tags: model.Tags,
+ }
+}
+
+// createServiceFromModel creates an LLM service from a database model configuration
+func (m *Manager) createServiceFromModel(model *generated.Model) llm.Service {
+ switch model.ProviderType {
+ case "anthropic":
+ return &ant.Service{
+ APIKey: model.ApiKey,
+ URL: model.Endpoint,
+ Model: model.ModelName,
+ HTTPC: m.httpc,
+ }
+ case "openai":
+ return &oai.Service{
+ APIKey: model.ApiKey,
+ ModelURL: model.Endpoint,
+ Model: oai.Model{
+ ModelName: model.ModelName,
+ URL: model.Endpoint,
+ },
+ MaxTokens: int(model.MaxTokens),
+ HTTPC: m.httpc,
+ }
+ case "openai-responses":
+ return &oai.ResponsesService{
+ APIKey: model.ApiKey,
+ ModelURL: model.Endpoint,
+ Model: oai.Model{
+ ModelName: model.ModelName,
+ URL: model.Endpoint,
+ },
+ MaxTokens: int(model.MaxTokens),
+ HTTPC: m.httpc,
+ }
+ case "gemini":
+ return &gem.Service{
+ APIKey: model.ApiKey,
+ URL: model.Endpoint,
+ Model: model.ModelName,
+ HTTPC: m.httpc,
+ }
+ default:
+ if m.logger != nil {
+ m.logger.Error("Unknown provider type for model", "model_id", model.ModelID, "provider_type", model.ProviderType)
+ }
+ return nil
+ }
+}
@@ -36,7 +36,7 @@ func TestByID(t *testing.T) {
wantNil bool
}{
{id: "qwen3-coder-fireworks", wantID: "qwen3-coder-fireworks", wantNil: false},
- {id: "gpt-5", wantID: "gpt-5", wantNil: false},
+ {id: "gpt-5.2-codex", wantID: "gpt-5.2-codex", wantNil: false},
{id: "claude-sonnet-4.5", wantID: "claude-sonnet-4.5", wantNil: false},
{id: "claude-haiku-4.5", wantID: "claude-haiku-4.5", wantNil: false},
{id: "claude-opus-4.5", wantID: "claude-opus-4.5", wantNil: false},
@@ -19,6 +19,7 @@ import (
"shelley.exe.dev/db/generated"
"shelley.exe.dev/llm"
"shelley.exe.dev/llm/ant"
+ "shelley.exe.dev/models"
)
// ClaudeTestHarness extends TestHarness with Claude-specific functionality
@@ -544,6 +545,13 @@ func (m *claudeLLMManager) HasModel(modelID string) bool {
return modelID == "claude" || modelID == "claude-haiku-4.5"
}
+func (m *claudeLLMManager) GetModelInfo(modelID string) *models.ModelInfo {
+ if modelID == "claude-haiku-4.5" {
+ return &models.ModelInfo{DisplayName: "Claude Haiku", Tags: "slug"}
+ }
+ return nil
+}
+
// TestClaudeCancelDuringToolCall tests cancellation during tool execution with Claude
func TestClaudeCancelDuringToolCall(t *testing.T) {
h := NewClaudeTestHarness(t)
@@ -15,6 +15,7 @@ import (
"shelley.exe.dev/db/generated"
"shelley.exe.dev/llm"
"shelley.exe.dev/loop"
+ "shelley.exe.dev/models"
)
// setupTestDB creates a test database
@@ -374,3 +375,7 @@ func (m *testLLMManager) GetAvailableModels() []string {
func (m *testLLMManager) HasModel(modelID string) bool {
return modelID == "predictable"
}
+
+func (m *testLLMManager) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
@@ -0,0 +1,404 @@
+package server
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+ "shelley.exe.dev/db/generated"
+ "shelley.exe.dev/llm"
+ "shelley.exe.dev/llm/ant"
+ "shelley.exe.dev/llm/gem"
+ "shelley.exe.dev/llm/oai"
+)
+
+// ModelAPI is the API representation of a model
+type ModelAPI struct {
+ ModelID string `json:"model_id"`
+ DisplayName string `json:"display_name"`
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ APIKey string `json:"api_key"`
+ ModelName string `json:"model_name"`
+ MaxTokens int64 `json:"max_tokens"`
+ Tags string `json:"tags"` // Comma-separated tags (e.g., "slug" for slug generation)
+}
+
+// CreateModelRequest is the request body for creating a model
+type CreateModelRequest struct {
+ DisplayName string `json:"display_name"`
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ APIKey string `json:"api_key"`
+ ModelName string `json:"model_name"`
+ MaxTokens int64 `json:"max_tokens"`
+ Tags string `json:"tags"` // Comma-separated tags
+}
+
+// UpdateModelRequest is the request body for updating a model
+type UpdateModelRequest struct {
+ DisplayName string `json:"display_name"`
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ APIKey string `json:"api_key"` // Empty string means keep existing
+ ModelName string `json:"model_name"`
+ MaxTokens int64 `json:"max_tokens"`
+ Tags string `json:"tags"` // Comma-separated tags
+}
+
+// TestModelRequest is the request body for testing a model
+type TestModelRequest struct {
+ ModelID string `json:"model_id,omitempty"` // If provided, use stored API key
+ ProviderType string `json:"provider_type"`
+ Endpoint string `json:"endpoint"`
+ APIKey string `json:"api_key"`
+ ModelName string `json:"model_name"`
+}
+
+func toModelAPI(m generated.Model) ModelAPI {
+ return ModelAPI{
+ ModelID: m.ModelID,
+ DisplayName: m.DisplayName,
+ ProviderType: m.ProviderType,
+ Endpoint: m.Endpoint,
+ APIKey: m.ApiKey,
+ ModelName: m.ModelName,
+ MaxTokens: m.MaxTokens,
+ Tags: m.Tags,
+ }
+}
+
+func (s *Server) handleCustomModels(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ s.handleListModels(w, r)
+ case http.MethodPost:
+ s.handleCreateModel(w, r)
+ default:
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+}
+
+func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
+ models, err := s.db.GetModels(r.Context())
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to get models: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ apiModels := make([]ModelAPI, len(models))
+ for i, m := range models {
+ apiModels[i] = toModelAPI(m)
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(apiModels)
+}
+
+func (s *Server) handleCreateModel(w http.ResponseWriter, r *http.Request) {
+ var req CreateModelRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
+ return
+ }
+
+ // Validate required fields
+ if req.DisplayName == "" || req.ProviderType == "" || req.Endpoint == "" || req.APIKey == "" || req.ModelName == "" {
+ http.Error(w, "display_name, provider_type, endpoint, api_key, and model_name are required", http.StatusBadRequest)
+ return
+ }
+
+ // Validate provider type
+ if req.ProviderType != "anthropic" && req.ProviderType != "openai" && req.ProviderType != "openai-responses" && req.ProviderType != "gemini" {
+ http.Error(w, "provider_type must be 'anthropic', 'openai', 'openai-responses', or 'gemini'", http.StatusBadRequest)
+ return
+ }
+
+ // Generate model ID
+ modelID := "custom-" + uuid.New().String()[:8]
+
+ // Default max tokens
+ if req.MaxTokens <= 0 {
+ req.MaxTokens = 200000
+ }
+
+ model, err := s.db.CreateModel(r.Context(), generated.CreateModelParams{
+ ModelID: modelID,
+ DisplayName: req.DisplayName,
+ ProviderType: req.ProviderType,
+ Endpoint: req.Endpoint,
+ ApiKey: req.APIKey,
+ ModelName: req.ModelName,
+ MaxTokens: req.MaxTokens,
+ Tags: req.Tags,
+ })
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to create model: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusCreated)
+ json.NewEncoder(w).Encode(toModelAPI(*model))
+}
+
+func (s *Server) handleCustomModel(w http.ResponseWriter, r *http.Request) {
+ // Extract model ID from URL path: /api/custom-models/{id} or /api/custom-models/{id}/duplicate
+ path := strings.TrimPrefix(r.URL.Path, "/api/custom-models/")
+ if path == "" {
+ http.Error(w, "Invalid model ID", http.StatusBadRequest)
+ return
+ }
+
+ // Check for /duplicate suffix
+ if strings.HasSuffix(path, "/duplicate") {
+ modelID := strings.TrimSuffix(path, "/duplicate")
+ if r.Method == http.MethodPost {
+ s.handleDuplicateModel(w, r, modelID)
+ } else {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+ return
+ }
+
+ if strings.Contains(path, "/") {
+ http.Error(w, "Invalid model ID", http.StatusBadRequest)
+ return
+ }
+ modelID := path
+
+ switch r.Method {
+ case http.MethodGet:
+ s.handleGetModel(w, r, modelID)
+ case http.MethodPut:
+ s.handleUpdateModel(w, r, modelID)
+ case http.MethodDelete:
+ s.handleDeleteModel(w, r, modelID)
+ default:
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+}
+
+func (s *Server) handleGetModel(w http.ResponseWriter, r *http.Request, modelID string) {
+ model, err := s.db.GetModel(r.Context(), modelID)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to get model: %v", err), http.StatusNotFound)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(toModelAPI(*model))
+}
+
+func (s *Server) handleUpdateModel(w http.ResponseWriter, r *http.Request, modelID string) {
+ // First, get the existing model to get the current API key if not provided
+ existing, err := s.db.GetModel(r.Context(), modelID)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Model not found: %v", err), http.StatusNotFound)
+ return
+ }
+
+ var req UpdateModelRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
+ return
+ }
+
+ // Use existing API key if not provided
+ apiKey := req.APIKey
+ if apiKey == "" {
+ apiKey = existing.ApiKey
+ }
+
+ // Default max tokens
+ if req.MaxTokens <= 0 {
+ req.MaxTokens = 200000
+ }
+
+ model, err := s.db.UpdateModel(r.Context(), generated.UpdateModelParams{
+ DisplayName: req.DisplayName,
+ ProviderType: req.ProviderType,
+ Endpoint: req.Endpoint,
+ ApiKey: apiKey,
+ ModelName: req.ModelName,
+ MaxTokens: req.MaxTokens,
+ Tags: req.Tags,
+ ModelID: modelID,
+ })
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to update model: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(toModelAPI(*model))
+}
+
+func (s *Server) handleDeleteModel(w http.ResponseWriter, r *http.Request, modelID string) {
+ err := s.db.DeleteModel(r.Context(), modelID)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to delete model: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ w.WriteHeader(http.StatusNoContent)
+}
+
+// DuplicateModelRequest allows overriding fields when duplicating
+type DuplicateModelRequest struct {
+ DisplayName string `json:"display_name,omitempty"`
+}
+
+func (s *Server) handleDuplicateModel(w http.ResponseWriter, r *http.Request, modelID string) {
+ // Get the source model (including API key)
+ source, err := s.db.GetModel(r.Context(), modelID)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Source model not found: %v", err), http.StatusNotFound)
+ return
+ }
+
+ // Parse optional overrides
+ var req DuplicateModelRequest
+ if r.Body != nil {
+ json.NewDecoder(r.Body).Decode(&req) // Ignore errors - all fields optional
+ }
+
+ // Generate new model ID
+ newModelID := "custom-" + uuid.New().String()[:8]
+
+ // Use provided display name or generate one
+ displayName := req.DisplayName
+ if displayName == "" {
+ displayName = source.DisplayName + " (copy)"
+ }
+
+ // Create the duplicate with the same API key
+ model, err := s.db.CreateModel(r.Context(), generated.CreateModelParams{
+ ModelID: newModelID,
+ DisplayName: displayName,
+ ProviderType: source.ProviderType,
+ Endpoint: source.Endpoint,
+ ApiKey: source.ApiKey, // Copy the API key!
+ ModelName: source.ModelName,
+ MaxTokens: source.MaxTokens,
+ Tags: "", // Don't copy tags
+ })
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to duplicate model: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusCreated)
+ json.NewEncoder(w).Encode(toModelAPI(*model))
+}
+
+func (s *Server) handleTestModel(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ var req TestModelRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
+ return
+ }
+
+ // If model_id is provided and api_key is empty, look up the stored key
+ if req.ModelID != "" && req.APIKey == "" {
+ model, err := s.db.GetModel(r.Context(), req.ModelID)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Model not found: %v", err), http.StatusNotFound)
+ return
+ }
+ req.APIKey = model.ApiKey
+ }
+
+ if req.ProviderType == "" || req.Endpoint == "" || req.APIKey == "" || req.ModelName == "" {
+ http.Error(w, "provider_type, endpoint, api_key, and model_name are required", http.StatusBadRequest)
+ return
+ }
+
+ // Create the appropriate service based on provider type
+ var service llm.Service
+ switch req.ProviderType {
+ case "anthropic":
+ service = &ant.Service{
+ APIKey: req.APIKey,
+ URL: req.Endpoint,
+ Model: req.ModelName,
+ }
+ case "openai":
+ service = &oai.Service{
+ APIKey: req.APIKey,
+ ModelURL: req.Endpoint,
+ Model: oai.Model{
+ ModelName: req.ModelName,
+ URL: req.Endpoint,
+ },
+ }
+ case "gemini":
+ service = &gem.Service{
+ APIKey: req.APIKey,
+ URL: req.Endpoint,
+ Model: req.ModelName,
+ }
+ case "openai-responses":
+ service = &oai.ResponsesService{
+ APIKey: req.APIKey,
+ Model: oai.Model{
+ ModelName: req.ModelName,
+ URL: req.Endpoint,
+ },
+ }
+ default:
+ http.Error(w, "Invalid provider_type", http.StatusBadRequest)
+ return
+ }
+
+ // Send a simple test request
+ ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
+ defer cancel()
+
+ request := &llm.Request{
+ Messages: []llm.Message{
+ {
+ Role: llm.MessageRoleUser,
+ Content: []llm.Content{
+ {Type: llm.ContentTypeText, Text: "Say 'test successful' in exactly two words."},
+ },
+ },
+ },
+ }
+
+ response, err := service.Do(ctx, request)
+ if err != nil {
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]interface{}{
+ "success": false,
+ "message": fmt.Sprintf("Test failed: %v", err),
+ })
+ return
+ }
+
+ // Check if we got a response
+ if len(response.Content) == 0 || response.Content[0].Text == "" {
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]interface{}{
+ "success": false,
+ "message": "Test failed: empty response from model",
+ })
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]interface{}{
+ "success": true,
+ "message": fmt.Sprintf("Test successful! Response: %s", response.Content[0].Text),
+ })
+}
@@ -182,6 +182,8 @@ tr:hover { background: #252525; }
}
.tab-content { display: none; }
.tab-content.active { display: block; }
+.model-display { color: #a5d6ff; }
+.model-id { color: #888; font-size: 11px; }
</style>
</head>
<body>
@@ -228,6 +230,13 @@ function formatDuration(ms) {
return (ms / 1000).toFixed(2) + 's';
}
+function formatModel(model, displayName) {
+ if (displayName) {
+ return '<span class="model-display">' + displayName + '</span> <span class="model-id">(' + model + ')</span>';
+ }
+ return model;
+}
+
function syntaxHighlight(json) {
if (typeof json !== 'string') json = JSON.stringify(json, null, 2);
json = json.replace(/&/g, '&').replace(/</g, '<').replace(/>/g, '>');
@@ -254,7 +263,7 @@ async function loadRequests() {
const data = await resp.json();
renderTable(data);
} catch (e) {
- document.getElementById('requests-body').innerHTML =
+ document.getElementById('requests-body').innerHTML =
'<tr><td colspan="10" class="error">Error loading requests: ' + e.message + '</td></tr>';
}
}
@@ -269,20 +278,20 @@ function renderTable(requests) {
for (const req of requests) {
const tr = document.createElement('tr');
tr.id = 'row-' + req.id;
-
- const statusClass = req.status_code && req.status_code >= 200 && req.status_code < 300 ? 'success' :
+
+ const statusClass = req.status_code && req.status_code >= 200 && req.status_code < 300 ? 'success' :
(req.status_code ? 'error' : '');
-
+
let prefixInfo = '-';
if (req.prefix_request_id) {
- prefixInfo = '<span class="dedup-info">prefix from #' + req.prefix_request_id +
+ prefixInfo = '<span class="dedup-info">prefix from #' + req.prefix_request_id +
' (' + formatSize(req.prefix_length) + ')</span>';
}
-
+
tr.innerHTML = ` + "`" + `
<td class="mono">${req.id}</td>
<td>${formatDate(req.created_at)}</td>
- <td>${req.model}</td>
+ <td>${formatModel(req.model, req.model_display_name)}</td>
<td>${req.provider}</td>
<td class="${statusClass}">${req.status_code || '-'}${req.error ? ' ⚠' : ''}</td>
<td>${formatDuration(req.duration_ms)}</td>
@@ -302,7 +311,7 @@ async function toggleExpand(id) {
expandedRows.delete(id);
return;
}
-
+
expandedRows.add(id);
const row = document.getElementById('row-' + id);
const expandRow = document.createElement('tr');
@@ -325,7 +334,7 @@ async function toggleExpand(id) {
</td>
` + "`" + `;
row.after(expandRow);
-
+
// Load request body
loadBody(id, 'request');
}
@@ -336,9 +345,9 @@ async function loadBody(id, type) {
renderBody(id, type, loadedData[key]);
return;
}
-
+
try {
- const url = type === 'request'
+ const url = type === 'request'
? '/debug/llm_requests/' + id + '/request'
: '/debug/llm_requests/' + id + '/response';
const resp = await fetch(url);
@@ -363,13 +372,13 @@ async function loadBody(id, type) {
function renderBody(id, type, data) {
const container = document.querySelector('#tab-' + type + '-' + id + ' pre');
if (!container) return;
-
+
if (data === null) {
container.className = '';
container.textContent = '(empty)';
return;
}
-
+
container.className = '';
if (typeof data === 'object') {
container.innerHTML = syntaxHighlight(JSON.stringify(data, null, 2));
@@ -382,14 +391,14 @@ function showTab(id, tab) {
// Update tab buttons
const expandRow = document.getElementById('expand-' + id);
if (!expandRow) return;
-
+
expandRow.querySelectorAll('.tab-btn').forEach(btn => {
btn.classList.remove('active');
if (btn.textContent.toLowerCase() === tab) {
btn.classList.add('active');
}
});
-
+
// Update tab content
expandRow.querySelectorAll('.tab-content').forEach(content => {
content.classList.remove('active');
@@ -357,30 +357,7 @@ func (s *Server) serveIndexWithInit(w http.ResponseWriter, r *http.Request, fs h
}
// Build initialization data
- type ModelInfo struct {
- ID string `json:"id"`
- Ready bool `json:"ready"`
- MaxContextTokens int `json:"max_context_tokens,omitempty"`
- }
-
- var modelList []ModelInfo
- if s.predictableOnly {
- modelList = append(modelList, ModelInfo{ID: "predictable", Ready: true, MaxContextTokens: 200000})
- } else {
- modelIDs := s.llmManager.GetAvailableModels()
- for _, id := range modelIDs {
- // Skip predictable model unless predictable-only flag is set
- if id == "predictable" {
- continue
- }
- svc, err := s.llmManager.GetService(id)
- maxCtx := 0
- if err == nil && svc != nil {
- maxCtx = svc.TokenContextWindow()
- }
- modelList = append(modelList, ModelInfo{ID: id, Ready: err == nil, MaxContextTokens: maxCtx})
- }
- }
+ modelList := s.getModelList()
// Select default model - use configured default if available, otherwise first ready model
defaultModel := s.defaultModel
@@ -913,6 +890,52 @@ func (s *Server) handleVersion(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(version.GetInfo())
}
+// ModelInfo represents a model in the API response
+type ModelInfo struct {
+ ID string `json:"id"`
+ DisplayName string `json:"display_name,omitempty"`
+ Ready bool `json:"ready"`
+ MaxContextTokens int `json:"max_context_tokens,omitempty"`
+}
+
+// getModelList returns the list of available models
+func (s *Server) getModelList() []ModelInfo {
+ var modelList []ModelInfo
+ if s.predictableOnly {
+ modelList = append(modelList, ModelInfo{ID: "predictable", Ready: true, MaxContextTokens: 200000})
+ } else {
+ modelIDs := s.llmManager.GetAvailableModels()
+ for _, id := range modelIDs {
+ // Skip predictable model unless predictable-only flag is set
+ if id == "predictable" {
+ continue
+ }
+ svc, err := s.llmManager.GetService(id)
+ maxCtx := 0
+ if err == nil && svc != nil {
+ maxCtx = svc.TokenContextWindow()
+ }
+ info := ModelInfo{ID: id, Ready: err == nil, MaxContextTokens: maxCtx}
+ // Add display name from model info
+ if modelInfo := s.llmManager.GetModelInfo(id); modelInfo != nil {
+ info.DisplayName = modelInfo.DisplayName
+ }
+ modelList = append(modelList, info)
+ }
+ }
+ return modelList
+}
+
+// handleModels returns the list of available models
+func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(s.getModelList())
+}
+
// handleArchivedConversations handles GET /api/conversations/archived
func (s *Server) handleArchivedConversations(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
@@ -69,6 +69,7 @@ type LLMProvider interface {
GetService(modelID string) (llm.Service, error)
GetAvailableModels() []string
HasModel(modelID string) bool
+ GetModelInfo(modelID string) *models.ModelInfo
}
// NewLLMServiceManager creates a new LLM service manager from config
@@ -261,6 +262,14 @@ func (s *Server) RegisterRoutes(mux *http.ServeMux) {
mux.HandleFunc("/api/read", s.handleRead) // Serves images
mux.Handle("/api/write-file", http.HandlerFunc(s.handleWriteFile)) // Small response
+ // Custom models API
+ mux.Handle("/api/custom-models", http.HandlerFunc(s.handleCustomModels))
+ mux.Handle("/api/custom-models/", http.HandlerFunc(s.handleCustomModel))
+ mux.Handle("/api/custom-models-test", http.HandlerFunc(s.handleTestModel))
+
+ // Models API (dynamic list refresh)
+ mux.Handle("/api/models", http.HandlerFunc(s.handleModels))
+
// Version endpoints
mux.Handle("GET /version", http.HandlerFunc(s.handleVersion))
mux.Handle("GET /version-check", http.HandlerFunc(s.handleVersionCheck))
@@ -10,15 +10,18 @@ import (
"shelley.exe.dev/db"
"shelley.exe.dev/llm"
+ "shelley.exe.dev/models"
)
// LLMServiceProvider defines the interface for getting LLM services
type LLMServiceProvider interface {
GetService(modelID string) (llm.Service, error)
+ GetAvailableModels() []string
+ GetModelInfo(modelID string) *models.ModelInfo
}
// GenerateSlug generates a slug for a conversation and updates the database
-// If conversationModelID is provided, it will try to use that model first before falling back to the default list
+// If conversationModelID is provided, it will be used as a fallback if no model is tagged with "slug"
func GenerateSlug(ctx context.Context, llmProvider LLMServiceProvider, database *db.DB, logger *slog.Logger, conversationID, userMessage, conversationModelID string) (string, error) {
baseSlug, err := generateSlugText(ctx, llmProvider, logger, userMessage, conversationModelID)
if err != nil {
@@ -53,15 +56,14 @@ func GenerateSlug(ctx context.Context, llmProvider LLMServiceProvider, database
}
// generateSlugText generates a human-readable slug for a conversation based on the user message
-// If conversationModelID is "predictable", it will be used instead of the default preferred models
+// Priority order:
+// 1. If conversationModelID is "predictable", use it
+// 2. Try models tagged with "slug"
+// 3. Fall back to the conversation's model (conversationModelID)
func generateSlugText(ctx context.Context, llmProvider LLMServiceProvider, logger *slog.Logger, userMessage, conversationModelID string) (string, error) {
- // Try different models in order of preference
var llmService llm.Service
var err error
- // Preferred models in order of preference
- preferredModels := []string{"qwen3-coder-fireworks", "gpt5-mini", "gpt-5-thinking-mini", "claude-sonnet-4.5", "predictable"}
-
// If conversation is using predictable model, use it for slug generation too
if conversationModelID == "predictable" {
llmService, err = llmProvider.GetService("predictable")
@@ -72,15 +74,29 @@ func generateSlugText(ctx context.Context, llmProvider LLMServiceProvider, logge
}
}
- // If we didn't get the predictable service, try the preferred models
+ // Try models tagged with "slug"
if llmService == nil {
- for _, model := range preferredModels {
- llmService, err = llmProvider.GetService(model)
- if err == nil {
- logger.Debug("Using preferred model for slug generation", "model", model)
+ for _, modelID := range llmProvider.GetAvailableModels() {
+ info := llmProvider.GetModelInfo(modelID)
+ if info != nil && strings.Contains(info.Tags, "slug") {
+ llmService, err = llmProvider.GetService(modelID)
+ if err == nil {
+ logger.Debug("Using slug-tagged model for slug generation", "model", modelID)
+ } else {
+ logger.Debug("Failed to get slug-tagged model", "model", modelID, "error", err)
+ }
break
}
- logger.Debug("Model not available for slug generation", "model", model, "error", err)
+ }
+ }
+
+ // Fall back to the conversation's model
+ if llmService == nil && conversationModelID != "" && conversationModelID != "predictable" {
+ llmService, err = llmProvider.GetService(conversationModelID)
+ if err == nil {
+ logger.Debug("Using conversation model for slug generation", "model", conversationModelID)
+ } else {
+ logger.Debug("Conversation model not available for slug generation", "model", conversationModelID, "error", err)
}
}
@@ -134,9 +150,6 @@ Respond with only the slug, nothing else.`, userMessage)
return "", fmt.Errorf("generated slug is empty after sanitization")
}
- // Note: We don't check for uniqueness here since we're generating for a new conversation
- // and the database will handle any conflicts
-
return slug, nil
}
@@ -9,6 +9,7 @@ import (
"shelley.exe.dev/db"
"shelley.exe.dev/llm"
+ "shelley.exe.dev/models"
)
func TestSanitize(t *testing.T) {
@@ -100,6 +101,14 @@ func (m *MockLLMProvider) GetService(modelID string) (llm.Service, error) {
return m.Service, nil
}
+func (m *MockLLMProvider) GetAvailableModels() []string {
+ return []string{"mock"}
+}
+
+func (m *MockLLMProvider) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
+
// TestGenerateSlug_DatabaseIntegration tests slug generation with actual database conflicts
func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
// Create temporary database
@@ -135,7 +144,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
}
// Generate first slug - should succeed with "test-slug"
- slug1, err := GenerateSlug(ctx, mockLLM, database, logger, conv1.ConversationID, "Test message", "")
+ slug1, err := GenerateSlug(ctx, mockLLM, database, logger, conv1.ConversationID, "Test message", "test-model")
if err != nil {
t.Fatalf("Failed to generate first slug: %v", err)
}
@@ -150,7 +159,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
}
// Generate second slug - should get "test-slug-1" due to conflict
- slug2, err := GenerateSlug(ctx, mockLLM, database, logger, conv2.ConversationID, "Test message", "")
+ slug2, err := GenerateSlug(ctx, mockLLM, database, logger, conv2.ConversationID, "Test message", "test-model")
if err != nil {
t.Fatalf("Failed to generate second slug: %v", err)
}
@@ -165,7 +174,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
}
// Generate third slug - should get "test-slug-2" due to conflict
- slug3, err := GenerateSlug(ctx, mockLLM, database, logger, conv3.ConversationID, "Test message", "")
+ slug3, err := GenerateSlug(ctx, mockLLM, database, logger, conv3.ConversationID, "Test message", "test-model")
if err != nil {
t.Fatalf("Failed to generate third slug: %v", err)
}
@@ -203,6 +212,14 @@ func (m *MockLLMProviderWithError) GetService(modelID string) (llm.Service, erro
return nil, fmt.Errorf("model not available")
}
+func (m *MockLLMProviderWithError) GetAvailableModels() []string {
+ return []string{}
+}
+
+func (m *MockLLMProviderWithError) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
+
// MockLLMProviderWithServiceError provides a mock LLM provider that returns a service with error
type MockLLMProviderWithServiceError struct{}
@@ -210,6 +227,14 @@ func (m *MockLLMProviderWithServiceError) GetService(modelID string) (llm.Servic
return &MockLLMServiceWithError{}, nil
}
+func (m *MockLLMProviderWithServiceError) GetAvailableModels() []string {
+ return []string{"mock"}
+}
+
+func (m *MockLLMProviderWithServiceError) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
+
// TestGenerateSlug_LLMError tests error handling when LLM service fails
func TestGenerateSlug_LLMError(t *testing.T) {
mockLLM := &MockLLMProviderWithServiceError{}
@@ -218,8 +243,8 @@ func TestGenerateSlug_LLMError(t *testing.T) {
Level: slog.LevelWarn,
}))
- // Test that LLM error is properly propagated
- _, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "")
+ // Test that LLM error is properly propagated (pass a model ID so we get a service)
+ _, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "test-model")
if err == nil {
t.Error("Expected error from LLM service, got nil")
}
@@ -255,7 +280,7 @@ func TestGenerateSlug_EmptyResponse(t *testing.T) {
Level: slog.LevelWarn,
}))
- _, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "")
+ _, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "test-model")
if err == nil {
t.Error("Expected error for empty LLM response, got nil")
}
@@ -271,6 +296,14 @@ func (m *MockLLMProviderWithEmptyResponse) GetService(modelID string) (llm.Servi
return &MockLLMServiceEmptyResponse{}, nil
}
+func (m *MockLLMProviderWithEmptyResponse) GetAvailableModels() []string {
+ return []string{"mock"}
+}
+
+func (m *MockLLMProviderWithEmptyResponse) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
+
// MockLLMServiceEmptyResponse provides a mock LLM service that returns empty response
type MockLLMServiceEmptyResponse struct{}
@@ -301,7 +334,7 @@ func TestGenerateSlug_SanitizationError(t *testing.T) {
Level: slog.LevelWarn,
}))
- _, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "")
+ _, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "test-model")
if err == nil {
t.Error("Expected error for empty slug after sanitization, got nil")
}
@@ -357,7 +390,7 @@ func TestGenerateSlug_DatabaseError(t *testing.T) {
}
closedDB.Close()
- _, err = GenerateSlug(ctx, mockLLM, closedDB, logger, "test-conversation-id", "Test message", "")
+ _, err = GenerateSlug(ctx, mockLLM, closedDB, logger, "test-conversation-id", "Test message", "test-model")
if err == nil {
t.Error("Expected database error, got nil")
}
@@ -386,9 +419,9 @@ func TestGenerateSlug_PredictableModel(t *testing.T) {
}
}
-// TestGenerateSlug_PredictableModelFallback tests fallback when predictable model is not available
-func TestGenerateSlug_PredictableModelFallback(t *testing.T) {
- // Mock LLM provider that doesn't have predictable model but has other models
+// TestGenerateSlug_ConversationModelFallback tests fallback to conversation model when no slug-tagged models exist
+func TestGenerateSlug_ConversationModelFallback(t *testing.T) {
+ // Mock LLM provider that doesn't have predictable model but has a conversation model
mockLLM := &MockLLMProviderPredictableFallback{
fallbackService: &MockLLMService{
ResponseText: "fallback-slug",
@@ -399,10 +432,10 @@ func TestGenerateSlug_PredictableModelFallback(t *testing.T) {
Level: slog.LevelDebug,
}))
- // Test that fallback to preferred models works when predictable is not available
- slug, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "predictable")
+ // Test that fallback to conversation model works when no slug-tagged models exist
+ slug, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "my-custom-model")
if err != nil {
- t.Fatalf("Failed to generate slug with fallback: %v", err)
+ t.Fatalf("Failed to generate slug with conversation model fallback: %v", err)
}
if slug != "fallback-slug" {
t.Errorf("Expected 'fallback-slug', got %q", slug)
@@ -420,3 +453,11 @@ func (m *MockLLMProviderPredictableFallback) GetService(modelID string) (llm.Ser
}
return m.fallbackService, nil
}
+
+func (m *MockLLMProviderPredictableFallback) GetAvailableModels() []string {
+ return []string{"my-custom-model"}
+}
+
+func (m *MockLLMProviderPredictableFallback) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
@@ -23,6 +23,7 @@ import (
"shelley.exe.dev/db/generated"
"shelley.exe.dev/llm"
"shelley.exe.dev/loop"
+ "shelley.exe.dev/models"
"shelley.exe.dev/server"
"shelley.exe.dev/slug"
)
@@ -907,6 +908,10 @@ func (m *inspectableLLMManager) HasModel(modelID string) bool {
return modelID == "predictable"
}
+func (m *inspectableLLMManager) GetModelInfo(modelID string) *models.ModelInfo {
+ return nil
+}
+
func TestVersionEndpoint(t *testing.T) {
// Create temp DB-backed server
ctx := context.Background()
@@ -2,6 +2,7 @@ import React, { useState, useEffect, useCallback, useRef } from "react";
import ChatInterface from "./components/ChatInterface";
import ConversationDrawer from "./components/ConversationDrawer";
import CommandPalette from "./components/CommandPalette";
+import ModelsModal from "./components/ModelsModal";
import { Conversation, ConversationWithState, ConversationListUpdate } from "./types";
import { api } from "./services/api";
@@ -65,6 +66,8 @@ function App() {
const [drawerCollapsed, setDrawerCollapsed] = useState(false);
const [commandPaletteOpen, setCommandPaletteOpen] = useState(false);
const [diffViewerTrigger, setDiffViewerTrigger] = useState(0);
+ const [modelsModalOpen, setModelsModalOpen] = useState(false);
+ const [modelsRefreshTrigger, setModelsRefreshTrigger] = useState(0);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const [subagentUpdate, setSubagentUpdate] = useState<Conversation | null>(null);
@@ -348,6 +351,7 @@ function App() {
isDrawerCollapsed={drawerCollapsed}
onToggleDrawerCollapse={toggleDrawerCollapsed}
openDiffViewerTrigger={diffViewerTrigger}
+ modelsRefreshTrigger={modelsRefreshTrigger}
/>
</div>
@@ -368,9 +372,19 @@ function App() {
setDiffViewerTrigger((prev) => prev + 1);
setCommandPaletteOpen(false);
}}
+ onOpenModelsModal={() => {
+ setModelsModalOpen(true);
+ setCommandPaletteOpen(false);
+ }}
hasCwd={!!(currentConversation?.cwd || mostRecentCwd)}
/>
+ <ModelsModal
+ isOpen={modelsModalOpen}
+ onClose={() => setModelsModalOpen(false)}
+ onModelsChanged={() => setModelsRefreshTrigger((prev) => prev + 1)}
+ />
+
{/* Backdrop for mobile drawer */}
{drawerOpen && (
<div className="backdrop hide-on-desktop" onClick={() => setDrawerOpen(false)} />
@@ -394,6 +394,7 @@ interface ChatInterfaceProps {
isDrawerCollapsed?: boolean;
onToggleDrawerCollapse?: () => void;
openDiffViewerTrigger?: number; // increment to trigger opening diff viewer
+ modelsRefreshTrigger?: number; // increment to trigger models list refresh
}
function ChatInterface({
@@ -409,12 +410,15 @@ function ChatInterface({
isDrawerCollapsed,
onToggleDrawerCollapse,
openDiffViewerTrigger,
+ modelsRefreshTrigger,
}: ChatInterfaceProps) {
const [messages, setMessages] = useState<Message[]>([]);
const [loading, setLoading] = useState(true);
const [sending, setSending] = useState(false);
const [error, setError] = useState<string | null>(null);
- const models = window.__SHELLEY_INIT__?.models || [];
+ const [models, setModels] = useState<
+ Array<{ id: string; display_name?: string; ready: boolean; max_context_tokens?: number }>
+ >(window.__SHELLEY_INIT__?.models || []);
const [selectedModel, setSelectedModelState] = useState<string>(() => {
// First check localStorage for a sticky model preference
const storedModel = localStorage.getItem("shelley_selected_model");
@@ -473,6 +477,26 @@ function ChatInterface({
setCwdInitialized(true);
}
}, [mostRecentCwd, cwdInitialized]);
+
+ // Refresh models list when triggered (e.g., after custom model changes) or when starting new conversation
+ useEffect(() => {
+ // Skip on initial mount with trigger=0, but always refresh when starting a new conversation
+ if (modelsRefreshTrigger === undefined) return;
+ if (modelsRefreshTrigger === 0 && conversationId !== null) return;
+ api
+ .getModels()
+ .then((newModels) => {
+ setModels(newModels);
+ // Also update the global init data so other components see the change
+ if (window.__SHELLEY_INIT__) {
+ window.__SHELLEY_INIT__.models = newModels;
+ }
+ })
+ .catch((err) => {
+ console.error("Failed to refresh models:", err);
+ });
+ }, [modelsRefreshTrigger, conversationId]);
+
const [cwdError, setCwdError] = useState<string | null>(null);
const [editingModel, setEditingModel] = useState(false);
const [showDirectoryPicker, setShowDirectoryPicker] = useState(false);
@@ -1459,7 +1483,7 @@ function ChatInterface({
>
{models.map((model) => (
<option key={model.id} value={model.id} disabled={!model.ready}>
- {model.id} {!model.ready ? "(not ready)" : ""}
+ {model.display_name || model.id} {!model.ready ? "(not ready)" : ""}
</option>
))}
</select>
@@ -1469,7 +1493,7 @@ function ChatInterface({
onClick={() => setEditingModel(true)}
disabled={sending}
>
- {selectedModel}
+ {models.find((m) => m.id === selectedModel)?.display_name || selectedModel}
</button>
)}
</div>
@@ -19,6 +19,7 @@ interface CommandPaletteProps {
onNewConversation: () => void;
onSelectConversation: (id: string) => void;
onOpenDiffViewer: () => void;
+ onOpenModelsModal: () => void;
hasCwd: boolean;
}
@@ -64,6 +65,7 @@ function CommandPalette({
onNewConversation,
onSelectConversation,
onOpenDiffViewer,
+ onOpenModelsModal,
hasCwd,
}: CommandPaletteProps) {
const [query, setQuery] = useState("");
@@ -161,8 +163,46 @@ function CommandPalette({
});
}
+ items.push({
+ id: "manage-models",
+ type: "action",
+ title: "Add/Remove Models/Keys",
+ subtitle: "Configure custom AI models and API keys",
+ icon: (
+ <svg fill="none" stroke="currentColor" viewBox="0 0 24 24" width="16" height="16">
+ <path
+ strokeLinecap="round"
+ strokeLinejoin="round"
+ strokeWidth={2}
+ d="M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z"
+ />
+ <path
+ strokeLinecap="round"
+ strokeLinejoin="round"
+ strokeWidth={2}
+ d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"
+ />
+ </svg>
+ ),
+ action: () => {
+ onOpenModelsModal();
+ onClose();
+ },
+ keywords: [
+ "model",
+ "key",
+ "api",
+ "configure",
+ "settings",
+ "anthropic",
+ "openai",
+ "gemini",
+ "custom",
+ ],
+ });
+
return items;
- }, [onNewConversation, onOpenDiffViewer, onClose, hasCwd]);
+ }, [onNewConversation, onOpenDiffViewer, onOpenModelsModal, onClose, hasCwd]);
// Convert conversations to command items
const conversationToItem = useCallback(
@@ -4,10 +4,12 @@ interface ModalProps {
isOpen: boolean;
onClose: () => void;
title: string;
+ titleRight?: React.ReactNode;
children: React.ReactNode;
+ className?: string;
}
-function Modal({ isOpen, onClose, title, children }: ModalProps) {
+function Modal({ isOpen, onClose, title, titleRight, children, className }: ModalProps) {
if (!isOpen) return null;
const handleBackdropClick = (e: React.MouseEvent) => {
@@ -18,10 +20,11 @@ function Modal({ isOpen, onClose, title, children }: ModalProps) {
return (
<div className="modal-overlay" onClick={handleBackdropClick}>
- <div className="modal">
+ <div className={`modal ${className || ""}`}>
{/* Header */}
<div className="modal-header">
<h2 className="modal-title">{title}</h2>
+ {titleRight && <div className="modal-title-right">{titleRight}</div>}
<button onClick={onClose} className="btn-icon" aria-label="Close modal">
<svg fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path
@@ -0,0 +1,593 @@
+import React, { useState, useEffect, useCallback } from "react";
+import Modal from "./Modal";
+import {
+ customModelsApi,
+ CustomModel,
+ CreateCustomModelRequest,
+ TestCustomModelRequest,
+} from "../services/api";
+
+interface ModelsModalProps {
+ isOpen: boolean;
+ onClose: () => void;
+ onModelsChanged?: () => void;
+}
+
+type ProviderType = "anthropic" | "openai" | "openai-responses" | "gemini";
+
+const DEFAULT_ENDPOINTS: Record<ProviderType, string> = {
+ anthropic: "https://api.anthropic.com/v1/messages",
+ openai: "https://api.openai.com/v1",
+ "openai-responses": "https://api.openai.com/v1",
+ gemini: "https://generativelanguage.googleapis.com/v1beta",
+};
+
+const PROVIDER_LABELS: Record<ProviderType, string> = {
+ anthropic: "Anthropic",
+ openai: "OpenAI (Chat API)",
+ "openai-responses": "OpenAI (Responses API)",
+ gemini: "Google Gemini",
+};
+
+const DEFAULT_MODELS: Record<ProviderType, { name: string; model_name: string }[]> = {
+ anthropic: [
+ { name: "Claude Sonnet 4.5", model_name: "claude-sonnet-4-5" },
+ { name: "Claude Opus 4.5", model_name: "claude-opus-4-5" },
+ { name: "Claude Haiku 4.5", model_name: "claude-haiku-4-5" },
+ ],
+ openai: [{ name: "GPT-5.2", model_name: "gpt-5.2" }],
+ "openai-responses": [{ name: "GPT-5.2 Codex", model_name: "gpt-5.2-codex" }],
+ gemini: [
+ { name: "Gemini 3 Pro", model_name: "gemini-3-pro-preview" },
+ { name: "Gemini 3 Flash", model_name: "gemini-3-flash-preview" },
+ ],
+};
+
+interface FormData {
+ display_name: string;
+ provider_type: ProviderType;
+ endpoint: string;
+ endpoint_custom: boolean;
+ api_key: string;
+ model_name: string;
+ max_tokens: number;
+ tags: string; // Comma-separated tags
+}
+
+const emptyForm: FormData = {
+ display_name: "",
+ provider_type: "anthropic",
+ endpoint: DEFAULT_ENDPOINTS.anthropic,
+ endpoint_custom: false,
+ api_key: "",
+ model_name: "",
+ max_tokens: 200000,
+ tags: "",
+};
+
+function ModelsModal({ isOpen, onClose, onModelsChanged }: ModelsModalProps) {
+ const [models, setModels] = useState<CustomModel[]>([]);
+ const [loading, setLoading] = useState(true);
+ const [error, setError] = useState<string | null>(null);
+ const [builtInModels, setBuiltInModels] = useState<string[]>([]);
+
+ // Form state
+ const [showForm, setShowForm] = useState(false);
+ const [editingModelId, setEditingModelId] = useState<string | null>(null);
+ const [form, setForm] = useState<FormData>(emptyForm);
+
+ // Test state
+ const [testing, setTesting] = useState(false);
+ const [testResult, setTestResult] = useState<{ success: boolean; message: string } | null>(null);
+
+ // Tooltip state
+ const [showTagsTooltip, setShowTagsTooltip] = useState(false);
+
+ const loadModels = useCallback(async () => {
+ try {
+ setLoading(true);
+ setError(null);
+ const result = await customModelsApi.getCustomModels();
+ setModels(result);
+ } catch (err) {
+ setError(err instanceof Error ? err.message : "Failed to load models");
+ } finally {
+ setLoading(false);
+ }
+ }, []);
+
+ useEffect(() => {
+ if (isOpen) {
+ loadModels();
+ // Get built-in models from init data
+ const initData = window.__SHELLEY_INIT__;
+ if (initData?.models) {
+ setBuiltInModels(initData.models.map((m) => m.id));
+ }
+ }
+ }, [isOpen, loadModels]);
+
+ const handleProviderChange = (provider: ProviderType) => {
+ setForm((prev) => ({
+ ...prev,
+ provider_type: provider,
+ endpoint: prev.endpoint_custom ? prev.endpoint : DEFAULT_ENDPOINTS[provider],
+ }));
+ };
+
+ const handleEndpointModeChange = (custom: boolean) => {
+ setForm((prev) => ({
+ ...prev,
+ endpoint_custom: custom,
+ endpoint: custom ? prev.endpoint : DEFAULT_ENDPOINTS[prev.provider_type],
+ }));
+ };
+
+ const handleSelectPresetModel = (preset: { name: string; model_name: string }) => {
+ setForm((prev) => ({
+ ...prev,
+ display_name: preset.name,
+ model_name: preset.model_name,
+ }));
+ };
+
+ const handleTest = async () => {
+ // Need model_name always, and either api_key or editing an existing model
+ if (!form.model_name) {
+ setTestResult({ success: false, message: "Model name is required" });
+ return;
+ }
+ if (!form.api_key && !editingModelId) {
+ setTestResult({ success: false, message: "API key is required" });
+ return;
+ }
+
+ setTesting(true);
+ setTestResult(null);
+
+ try {
+ const request: TestCustomModelRequest = {
+ model_id: editingModelId || undefined, // Pass model_id to use stored key
+ provider_type: form.provider_type,
+ endpoint: form.endpoint,
+ api_key: form.api_key,
+ model_name: form.model_name,
+ };
+ const result = await customModelsApi.testCustomModel(request);
+ setTestResult(result);
+ } catch (err) {
+ setTestResult({
+ success: false,
+ message: err instanceof Error ? err.message : "Test failed",
+ });
+ } finally {
+ setTesting(false);
+ }
+ };
+
+ const handleSave = async () => {
+ if (!form.display_name || !form.api_key || !form.model_name) {
+ setError("Display name, API key, and model name are required");
+ return;
+ }
+
+ try {
+ setError(null);
+ const request: CreateCustomModelRequest = {
+ display_name: form.display_name,
+ provider_type: form.provider_type,
+ endpoint: form.endpoint,
+ api_key: form.api_key,
+ model_name: form.model_name,
+ max_tokens: form.max_tokens,
+ tags: form.tags,
+ };
+
+ if (editingModelId) {
+ await customModelsApi.updateCustomModel(editingModelId, request);
+ } else {
+ await customModelsApi.createCustomModel(request);
+ }
+
+ setShowForm(false);
+ setEditingModelId(null);
+ setForm(emptyForm);
+ setTestResult(null);
+ await loadModels();
+ onModelsChanged?.();
+ } catch (err) {
+ setError(err instanceof Error ? err.message : "Failed to save model");
+ }
+ };
+
+ const handleEdit = (model: CustomModel) => {
+ setEditingModelId(model.model_id);
+ setForm({
+ display_name: model.display_name,
+ provider_type: model.provider_type,
+ endpoint: model.endpoint,
+ endpoint_custom: model.endpoint !== DEFAULT_ENDPOINTS[model.provider_type],
+ api_key: model.api_key,
+ model_name: model.model_name,
+ max_tokens: model.max_tokens,
+ tags: model.tags,
+ });
+ setShowForm(true);
+ setTestResult(null);
+ };
+
+ const handleDuplicate = async (model: CustomModel) => {
+ try {
+ setError(null);
+ await customModelsApi.duplicateCustomModel(model.model_id);
+ await loadModels();
+ onModelsChanged?.();
+ } catch (err) {
+ setError(err instanceof Error ? err.message : "Failed to duplicate model");
+ }
+ };
+
+ const handleDelete = async (modelId: string) => {
+ try {
+ setError(null);
+ await customModelsApi.deleteCustomModel(modelId);
+ await loadModels();
+ onModelsChanged?.();
+ } catch (err) {
+ setError(err instanceof Error ? err.message : "Failed to delete model");
+ }
+ };
+
+ const handleCancel = () => {
+ setShowForm(false);
+ setEditingModelId(null);
+ setForm(emptyForm);
+ setTestResult(null);
+ };
+
+ const handleAddNew = () => {
+ setEditingModelId(null);
+ setForm(emptyForm);
+ setShowForm(true);
+ setTestResult(null);
+ };
+
+ const headerRight = !showForm ? (
+ <button className="btn-primary btn-sm" onClick={handleAddNew}>
+ + Add Model
+ </button>
+ ) : null;
+
+ return (
+ <Modal
+ isOpen={isOpen}
+ onClose={onClose}
+ title="Manage Models"
+ titleRight={headerRight}
+ className="modal-wide"
+ >
+ <div className="models-modal">
+ {error && (
+ <div className="models-error">
+ {error}
+ <button onClick={() => setError(null)} className="models-error-dismiss">
+ ×
+ </button>
+ </div>
+ )}
+
+ {loading ? (
+ <div className="models-loading">
+ <div className="spinner"></div>
+ <span>Loading models...</span>
+ </div>
+ ) : showForm ? (
+ // Add/Edit form
+ <div className="model-form">
+ <h3>{editingModelId ? "Edit Model" : "Add Model"}</h3>
+
+ {/* Provider Selection */}
+ <div className="form-group">
+ <label>Provider / API Format</label>
+ <div className="provider-buttons">
+ {(["anthropic", "openai", "openai-responses", "gemini"] as ProviderType[]).map(
+ (p) => (
+ <button
+ key={p}
+ type="button"
+ className={`provider-btn ${form.provider_type === p ? "selected" : ""}`}
+ onClick={() => handleProviderChange(p)}
+ >
+ {PROVIDER_LABELS[p]}
+ </button>
+ ),
+ )}
+ </div>
+ </div>
+
+ {/* Endpoint Selection */}
+ <div className="form-group">
+ <label>Endpoint</label>
+ <div className="endpoint-toggle">
+ <button
+ type="button"
+ className={`toggle-btn ${!form.endpoint_custom ? "selected" : ""}`}
+ onClick={() => handleEndpointModeChange(false)}
+ >
+ Default
+ </button>
+ <button
+ type="button"
+ className={`toggle-btn ${form.endpoint_custom ? "selected" : ""}`}
+ onClick={() => handleEndpointModeChange(true)}
+ >
+ Custom
+ </button>
+ </div>
+ {form.endpoint_custom ? (
+ <input
+ type="text"
+ value={form.endpoint}
+ onChange={(e) => setForm((prev) => ({ ...prev, endpoint: e.target.value }))}
+ placeholder="https://..."
+ className="form-input"
+ />
+ ) : (
+ <div className="endpoint-display">{form.endpoint}</div>
+ )}
+ </div>
+
+ {/* Model Name with Presets */}
+ <div className="form-group">
+ <label>Model</label>
+ <div className="model-presets">
+ {DEFAULT_MODELS[form.provider_type].map((preset) => (
+ <button
+ key={preset.model_name}
+ type="button"
+ className={`preset-btn ${form.model_name === preset.model_name ? "selected" : ""}`}
+ onClick={() => handleSelectPresetModel(preset)}
+ >
+ {preset.name}
+ </button>
+ ))}
+ </div>
+ <input
+ type="text"
+ value={form.model_name}
+ onChange={(e) => setForm((prev) => ({ ...prev, model_name: e.target.value }))}
+ placeholder="Model name (e.g., claude-sonnet-4-5)"
+ className="form-input"
+ />
+ </div>
+
+ {/* Display Name */}
+ <div className="form-group">
+ <label>Display Name</label>
+ <input
+ type="text"
+ value={form.display_name}
+ onChange={(e) => setForm((prev) => ({ ...prev, display_name: e.target.value }))}
+ placeholder="Name shown in model selector"
+ className="form-input"
+ />
+ </div>
+
+ {/* API Key */}
+ <div className="form-group">
+ <label>API Key</label>
+ <input
+ type="text"
+ value={form.api_key}
+ onChange={(e) => setForm((prev) => ({ ...prev, api_key: e.target.value }))}
+ placeholder="Enter API key"
+ className="form-input"
+ autoComplete="off"
+ />
+ </div>
+
+ {/* Max Tokens */}
+ <div className="form-group">
+ <label>Max Context Tokens</label>
+ <input
+ type="number"
+ value={form.max_tokens}
+ onChange={(e) =>
+ setForm((prev) => ({ ...prev, max_tokens: parseInt(e.target.value) || 200000 }))
+ }
+ className="form-input"
+ />
+ </div>
+
+ {/* Tags */}
+ <div className="form-group">
+ <label>
+ Tags
+ <span
+ className="info-icon-wrapper"
+ onClick={(e) => {
+ e.preventDefault();
+ e.stopPropagation();
+ setShowTagsTooltip(!showTagsTooltip);
+ }}
+ >
+ <span className="info-icon">
+ <svg
+ fill="none"
+ stroke="currentColor"
+ viewBox="0 0 24 24"
+ width="14"
+ height="14"
+ >
+ <path
+ strokeLinecap="round"
+ strokeLinejoin="round"
+ strokeWidth={2}
+ d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
+ />
+ </svg>
+ </span>
+ {showTagsTooltip && (
+ <span className="info-tooltip">
+ Comma-separated tags for this model. Use "slug" to mark this model for
+ generating conversation titles. If no model has the "slug" tag, the
+ conversation's model will be used.
+ </span>
+ )}
+ </span>
+ </label>
+ <input
+ type="text"
+ value={form.tags}
+ onChange={(e) => setForm((prev) => ({ ...prev, tags: e.target.value }))}
+ placeholder="comma-separated, e.g., slug, cheap"
+ className="form-input"
+ />
+ </div>
+
+ {/* Test Result */}
+ {testResult && (
+ <div className={`test-result ${testResult.success ? "success" : "error"}`}>
+ {testResult.success ? "✓" : "✗"} {testResult.message}
+ </div>
+ )}
+
+ {/* Form Actions */}
+ <div className="form-actions">
+ <button type="button" className="btn-secondary" onClick={handleCancel}>
+ Cancel
+ </button>
+ <button
+ type="button"
+ className="btn-secondary"
+ onClick={handleTest}
+ disabled={testing || (!form.api_key && !editingModelId) || !form.model_name}
+ title={
+ !form.model_name
+ ? "Enter model name to test"
+ : !form.api_key && !editingModelId
+ ? "Enter API key to test"
+ : ""
+ }
+ >
+ {testing ? "Testing..." : "Test"}
+ </button>
+ <button
+ type="button"
+ className="btn-primary"
+ onClick={handleSave}
+ disabled={!form.display_name || !form.api_key || !form.model_name}
+ >
+ {editingModelId ? "Save" : "Add Model"}
+ </button>
+ </div>
+ </div>
+ ) : (
+ // Model List
+ <>
+ {models.length === 0 &&
+ builtInModels.length > 0 &&
+ builtInModels[0] !== "predictable" && (
+ <div className="models-info">
+ <p>Built-in models available:</p>
+ <ul className="builtin-list">
+ {builtInModels
+ .filter((m) => m !== "predictable")
+ .map((model) => (
+ <li key={model}>{model}</li>
+ ))}
+ </ul>
+ </div>
+ )}
+
+ {models.length > 0 && (
+ <div className="models-list">
+ {models.map((model) => (
+ <div key={model.model_id} className="model-card">
+ <div className="model-header">
+ <div className="model-info">
+ <span className="model-name">{model.display_name}</span>
+ <span className="model-provider">
+ {PROVIDER_LABELS[model.provider_type]}
+ </span>
+ {model.tags && (
+ <span className="model-badge" title={model.tags}>
+ {model.tags.split(",")[0]}
+ </span>
+ )}
+ </div>
+ <div className="model-actions">
+ <button
+ className="btn-icon"
+ onClick={() => handleDuplicate(model)}
+ title="Duplicate"
+ >
+ <svg
+ fill="none"
+ stroke="currentColor"
+ viewBox="0 0 24 24"
+ width="16"
+ height="16"
+ >
+ <path
+ strokeLinecap="round"
+ strokeLinejoin="round"
+ strokeWidth={2}
+ d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
+ />
+ </svg>
+ </button>
+ <button className="btn-icon" onClick={() => handleEdit(model)} title="Edit">
+ <svg
+ fill="none"
+ stroke="currentColor"
+ viewBox="0 0 24 24"
+ width="16"
+ height="16"
+ >
+ <path
+ strokeLinecap="round"
+ strokeLinejoin="round"
+ strokeWidth={2}
+ d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
+ />
+ </svg>
+ </button>
+ <button
+ className="btn-icon btn-danger"
+ onClick={() => handleDelete(model.model_id)}
+ title="Delete"
+ >
+ <svg
+ fill="none"
+ stroke="currentColor"
+ viewBox="0 0 24 24"
+ width="16"
+ height="16"
+ >
+ <path
+ strokeLinecap="round"
+ strokeLinejoin="round"
+ strokeWidth={2}
+ d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
+ />
+ </svg>
+ </button>
+ </div>
+ </div>
+ <div className="model-details">
+ <span className="model-api-name">{model.model_name}</span>
+ <span className="model-endpoint">{model.endpoint}</span>
+ </div>
+ </div>
+ ))}
+ </div>
+ )}
+ </>
+ )}
+ </div>
+ </Modal>
+ );
+}
+
+export default ModelsModal;
@@ -27,6 +27,16 @@ class ApiService {
return response.json();
}
+ async getModels(): Promise<
+ Array<{ id: string; display_name?: string; ready: boolean; max_context_tokens?: number }>
+ > {
+ const response = await fetch(`${this.baseUrl}/models`);
+ if (!response.ok) {
+ throw new Error(`Failed to get models: ${response.statusText}`);
+ }
+ return response.json();
+ }
+
async searchConversations(query: string): Promise<ConversationWithState[]> {
const params = new URLSearchParams({
q: query,
@@ -255,3 +265,115 @@ class ApiService {
}
export const api = new ApiService();
+
+// Custom models API
+export interface CustomModel {
+ model_id: string;
+ display_name: string;
+ provider_type: "anthropic" | "openai" | "openai-responses" | "gemini";
+ endpoint: string;
+ api_key: string;
+ model_name: string;
+ max_tokens: number;
+ tags: string; // Comma-separated tags (e.g., "slug" for slug generation)
+}
+
+export interface CreateCustomModelRequest {
+ display_name: string;
+ provider_type: "anthropic" | "openai" | "openai-responses" | "gemini";
+ endpoint: string;
+ api_key: string;
+ model_name: string;
+ max_tokens: number;
+ tags: string; // Comma-separated tags
+}
+
+export interface TestCustomModelRequest {
+ model_id?: string; // If provided with empty api_key, use stored key
+ provider_type: "anthropic" | "openai" | "openai-responses" | "gemini";
+ endpoint: string;
+ api_key: string;
+ model_name: string;
+}
+
+class CustomModelsApi {
+ private baseUrl = "/api";
+
+ private postHeaders = {
+ "Content-Type": "application/json",
+ "X-Shelley-Request": "1",
+ };
+
+ async getCustomModels(): Promise<CustomModel[]> {
+ const response = await fetch(`${this.baseUrl}/custom-models`);
+ if (!response.ok) {
+ throw new Error(`Failed to get custom models: ${response.statusText}`);
+ }
+ return response.json();
+ }
+
+ async createCustomModel(request: CreateCustomModelRequest): Promise<CustomModel> {
+ const response = await fetch(`${this.baseUrl}/custom-models`, {
+ method: "POST",
+ headers: this.postHeaders,
+ body: JSON.stringify(request),
+ });
+ if (!response.ok) {
+ throw new Error(`Failed to create custom model: ${response.statusText}`);
+ }
+ return response.json();
+ }
+
+ async updateCustomModel(
+ modelId: string,
+ request: Partial<CreateCustomModelRequest>,
+ ): Promise<CustomModel> {
+ const response = await fetch(`${this.baseUrl}/custom-models/${modelId}`, {
+ method: "PUT",
+ headers: this.postHeaders,
+ body: JSON.stringify(request),
+ });
+ if (!response.ok) {
+ throw new Error(`Failed to update custom model: ${response.statusText}`);
+ }
+ return response.json();
+ }
+
+ async deleteCustomModel(modelId: string): Promise<void> {
+ const response = await fetch(`${this.baseUrl}/custom-models/${modelId}`, {
+ method: "DELETE",
+ headers: { "X-Shelley-Request": "1" },
+ });
+ if (!response.ok) {
+ throw new Error(`Failed to delete custom model: ${response.statusText}`);
+ }
+ }
+
+ async duplicateCustomModel(modelId: string, displayName?: string): Promise<CustomModel> {
+ const response = await fetch(`${this.baseUrl}/custom-models/${modelId}/duplicate`, {
+ method: "POST",
+ headers: this.postHeaders,
+ body: JSON.stringify({ display_name: displayName }),
+ });
+ if (!response.ok) {
+ throw new Error(`Failed to duplicate custom model: ${response.statusText}`);
+ }
+ return response.json();
+ }
+
+ async testCustomModel(
+ request: TestCustomModelRequest,
+ ): Promise<{ success: boolean; message: string }> {
+ const response = await fetch(`${this.baseUrl}/custom-models-test`, {
+ method: "POST",
+ headers: this.postHeaders,
+ body: JSON.stringify(request),
+ });
+ if (!response.ok) {
+ throw new Error(`Failed to test custom model: ${response.statusText}`);
+ }
+ return response.json();
+ }
+}
+
+export const customModelsApi = new CustomModelsApi();
@@ -306,6 +306,11 @@ button {
background: var(--bg-tertiary);
}
+.btn-sm {
+ padding: 0.375rem 0.75rem;
+ font-size: 0.8125rem;
+}
+
.btn-icon {
padding: 0.5rem;
border-radius: 0.375rem;
@@ -1992,6 +1997,10 @@ button {
overflow: hidden;
}
+.modal.modal-wide {
+ max-width: 40rem;
+}
+
.modal-header {
display: flex;
align-items: center;
@@ -2005,8 +2014,15 @@ button {
font-weight: 600;
}
+.modal-title-right {
+ margin-left: auto;
+ margin-right: 1rem;
+}
+
.modal-body {
padding: 1rem;
+ overflow-y: auto;
+ max-height: calc(80vh - 60px);
}
/* Form Elements */
@@ -3795,6 +3811,158 @@ svg {
}
}
+/* Models Modal Styles */
+.models-modal {
+ padding: 0.5rem;
+}
+
+.models-loading {
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ gap: 0.5rem;
+ padding: 2rem;
+ color: var(--text-secondary);
+}
+
+.models-error {
+ display: flex;
+ align-items: center;
+ justify-content: space-between;
+ padding: 0.75rem 1rem;
+ background: var(--error-bg);
+ border: 1px solid var(--error-border);
+ border-radius: 0.375rem;
+ color: var(--error-text);
+ margin-bottom: 1rem;
+}
+
+.models-error-dismiss {
+ background: none;
+ border: none;
+ font-size: 1.25rem;
+ cursor: pointer;
+ color: inherit;
+ padding: 0;
+ line-height: 1;
+}
+
+.models-info {
+ padding: 0.75rem 1rem;
+ background: var(--bg-tertiary);
+ border-radius: 0.375rem;
+ margin-bottom: 1rem;
+}
+
+.models-info p {
+ margin: 0 0 0.5rem 0;
+ color: var(--text-secondary);
+}
+
+.builtin-list {
+ margin: 0;
+ padding-left: 1.25rem;
+ color: var(--text-primary);
+}
+
+.builtin-list li {
+ margin: 0.25rem 0;
+}
+
+.models-list {
+ display: flex;
+ flex-direction: column;
+ gap: 0.75rem;
+ margin-bottom: 1rem;
+}
+
+.model-card {
+ padding: 0.75rem 1rem;
+ background: var(--bg-base);
+ border: 1px solid var(--border);
+ border-radius: 0.375rem;
+}
+
+.model-header {
+ display: flex;
+ justify-content: space-between;
+ align-items: flex-start;
+ gap: 0.5rem;
+}
+
+.model-info {
+ display: flex;
+ flex-wrap: wrap;
+ align-items: center;
+ gap: 0.5rem;
+}
+
+.model-name {
+ font-weight: 500;
+ color: var(--text-primary);
+}
+
+.model-provider {
+ font-size: 0.75rem;
+ padding: 0.125rem 0.5rem;
+ background: var(--bg-tertiary);
+ border-radius: 0.25rem;
+ color: var(--text-secondary);
+}
+
+.model-badge {
+ font-size: 0.625rem;
+ text-transform: uppercase;
+ letter-spacing: 0.05em;
+ padding: 0.125rem 0.375rem;
+ background: var(--blue-bg);
+ border: 1px solid var(--blue-border);
+ border-radius: 0.25rem;
+ color: var(--blue-text);
+}
+
+.model-actions {
+ display: flex;
+ gap: 0.25rem;
+ flex-shrink: 0;
+}
+
+.model-details {
+ margin-top: 0.5rem;
+ display: flex;
+ flex-direction: column;
+ gap: 0.25rem;
+}
+
+.model-api-name {
+ font-size: 0.75rem;
+ font-family: var(--font-mono);
+ color: var(--text-secondary);
+}
+
+.model-endpoint {
+ font-size: 0.75rem;
+ color: var(--text-tertiary);
+ word-break: break-all;
+}
+
+.btn-icon {
+ padding: 0.25rem;
+ border-radius: 0.25rem;
+ color: var(--text-secondary);
+ background: transparent;
+ border: none;
+ cursor: pointer;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+}
+
+.btn-icon:hover {
+ background: var(--bg-tertiary);
+ color: var(--text-primary);
+}
+
/* Version Checker Styles */
.version-update-dot {
position: absolute;
@@ -4018,3 +4186,243 @@ svg {
.version-btn-primary:hover:not(:disabled) {
background: var(--primary-dark);
}
+
+.btn-icon.btn-danger:hover {
+ background: var(--error-bg);
+ color: var(--error-text);
+}
+
+.add-model-btn {
+ width: 100%;
+}
+
+/* Model Form */
+.model-form h3 {
+ margin: 0 0 0.75rem 0;
+ font-size: 1rem;
+ font-weight: 600;
+}
+
+.form-group {
+ margin-bottom: 0.75rem;
+}
+
+.form-group label {
+ display: block;
+ font-size: 0.875rem;
+ font-weight: 500;
+ margin-bottom: 0.25rem;
+ color: var(--text-primary);
+}
+
+.form-group label .optional {
+ font-weight: 400;
+ color: var(--text-secondary);
+}
+
+.form-input {
+ width: 100%;
+ padding: 0.5rem 0.75rem;
+ background: var(--bg-base);
+ border: 1px solid var(--border);
+ border-radius: 0.375rem;
+ color: var(--text-primary);
+ font-family: inherit;
+ font-size: 0.875rem;
+}
+
+.form-input:focus {
+ outline: none;
+ border-color: var(--primary);
+ box-shadow: 0 0 0 2px rgba(37, 99, 235, 0.2);
+}
+
+.form-checkbox {
+ display: flex;
+ align-items: center;
+}
+
+.form-checkbox label {
+ display: flex;
+ align-items: center;
+ gap: 0.5rem;
+ cursor: pointer;
+ margin: 0;
+}
+
+.form-checkbox input[type="checkbox"] {
+ width: 1rem;
+ height: 1rem;
+ cursor: pointer;
+}
+
+.info-icon-wrapper {
+ position: relative;
+ display: inline-flex;
+ align-items: center;
+ margin-left: 0.25rem;
+}
+
+.info-icon {
+ display: inline-flex;
+ align-items: center;
+ justify-content: center;
+ color: var(--text-tertiary);
+ cursor: pointer;
+}
+
+.info-icon:hover {
+ color: var(--text-secondary);
+}
+
+.info-tooltip {
+ position: absolute;
+ top: 50%;
+ left: 100%;
+ transform: translateY(-50%);
+ background: var(--bg-tertiary);
+ border: 1px solid var(--border);
+ border-radius: 0.375rem;
+ padding: 0.5rem 0.75rem;
+ font-size: 0.75rem;
+ font-weight: 400;
+ color: var(--text-primary);
+ white-space: normal;
+ width: 220px;
+ margin-left: 0.375rem;
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15);
+ z-index: 100;
+}
+
+.info-tooltip::after {
+ content: "";
+ position: absolute;
+ top: 50%;
+ right: 100%;
+ transform: translateY(-50%);
+ border: 6px solid transparent;
+ border-right-color: var(--border);
+}
+
+.provider-buttons {
+ display: flex;
+ gap: 0.5rem;
+}
+
+.provider-btn {
+ flex: 1;
+ padding: 0.5rem;
+ background: var(--bg-base);
+ border: 1px solid var(--border);
+ border-radius: 0.375rem;
+ color: var(--text-primary);
+ font-size: 0.875rem;
+ cursor: pointer;
+ transition: all 0.15s;
+}
+
+.provider-btn:hover {
+ background: var(--bg-tertiary);
+}
+
+.provider-btn.selected {
+ background: var(--primary);
+ border-color: var(--primary);
+ color: white;
+}
+
+.endpoint-toggle {
+ display: flex;
+ gap: 0;
+ margin-bottom: 0.5rem;
+}
+
+.toggle-btn {
+ flex: 1;
+ padding: 0.375rem 0.75rem;
+ background: var(--bg-base);
+ border: 1px solid var(--border);
+ color: var(--text-secondary);
+ font-size: 0.75rem;
+ cursor: pointer;
+ transition: all 0.15s;
+}
+
+.toggle-btn:first-child {
+ border-radius: 0.375rem 0 0 0.375rem;
+}
+
+.toggle-btn:last-child {
+ border-radius: 0 0.375rem 0.375rem 0;
+ border-left: none;
+}
+
+.toggle-btn.selected {
+ background: var(--bg-tertiary);
+ color: var(--text-primary);
+}
+
+.endpoint-display {
+ padding: 0.5rem 0.75rem;
+ background: var(--bg-tertiary);
+ border-radius: 0.375rem;
+ font-size: 0.75rem;
+ color: var(--text-secondary);
+ word-break: break-all;
+}
+
+.model-presets {
+ display: flex;
+ flex-wrap: wrap;
+ gap: 0.375rem;
+ margin-bottom: 0.5rem;
+}
+
+.preset-btn {
+ padding: 0.25rem 0.625rem;
+ background: var(--bg-base);
+ border: 1px solid var(--border);
+ border-radius: 0.25rem;
+ color: var(--text-secondary);
+ font-size: 0.75rem;
+ cursor: pointer;
+ transition: all 0.15s;
+}
+
+.preset-btn:hover {
+ background: var(--bg-tertiary);
+ color: var(--text-primary);
+}
+
+.preset-btn.selected {
+ background: var(--blue-bg);
+ border-color: var(--blue-border);
+ color: var(--blue-text);
+}
+
+.test-result {
+ padding: 0.75rem 1rem;
+ border-radius: 0.375rem;
+ margin-bottom: 1rem;
+ font-size: 0.875rem;
+}
+
+.test-result.success {
+ background: var(--success-bg);
+ border: 1px solid var(--success-border);
+ color: var(--success-text);
+}
+
+.test-result.error {
+ background: var(--error-bg);
+ border: 1px solid var(--error-border);
+ color: var(--error-text);
+}
+
+.form-actions {
+ display: flex;
+ gap: 0.5rem;
+ justify-content: flex-end;
+ margin-top: 1rem;
+ padding-bottom: 0.5rem;
+}
@@ -49,6 +49,7 @@ export interface LLMContent {
// API types
export interface Model {
id: string;
+ display_name?: string;
ready: boolean;
max_context_tokens?: number;
}