diff --git a/AGENTS.md b/AGENTS.md index 9203f930028e00e54849f1963328950ee2901452..d0a64123d75c6494362d23c32cfd5b35bf18651f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -27,3 +27,5 @@ ./bin/shelley -config /exe.dev/shelley.json -db /tmp/shelley-test.db serve -port 8002 ``` Then use browser tools to navigate to http://localhost:8002/ and interact with the UI. +13. NEVER use alert(), confirm(), or prompt(). Use proper UI components like tooltips, modals, or toasts instead. +14. SQL migrations and frontend changes require rebuilding the binary (`make build` or `go generate ./... && cd ui && pnpm run build`). diff --git a/cmd/shelley/main.go b/cmd/shelley/main.go index f0e450a4a344c352d442b30cbdfb7a0cf2a04baf..34c17cb6ac9e31cb3335a6747de1b372f9128c83 100644 --- a/cmd/shelley/main.go +++ b/cmd/shelley/main.go @@ -101,7 +101,7 @@ func runServe(global GlobalConfig, args []string) { // Build LLM configuration llmConfig := buildLLMConfig(logger, global.ConfigPath, global.TerminalURL, global.DefaultModel, database) - // Initialize LLM service manager + // Initialize LLM service manager (includes custom model support via database) llmManager := server.NewLLMServiceManager(llmConfig) // Log available models diff --git a/db/db.go b/db/db.go index d95a2baf0f718203d33f842dea9c7af0def7a032..9f7e3fdd71d1cfe13a88407d09abdaeaba175476 100644 --- a/db/db.go +++ b/db/db.go @@ -115,6 +115,20 @@ func (db *DB) Migrate(ctx context.Context) error { // Sort migrations by number sort.Strings(migrations) + // Check for duplicate migration numbers + seenNumbers := make(map[string]string) // number -> filename + for _, migration := range migrations { + matches := migrationPattern.FindStringSubmatch(migration) + if len(matches) < 2 { + continue + } + num := matches[1] + if existing, ok := seenNumbers[num]; ok { + return fmt.Errorf("duplicate migration number %s: %s and %s", num, existing, migration) + } + seenNumbers[num] = migration + } + // Get executed migrations executedMigrations := make(map[int]bool) var tableName string @@ -850,3 +864,68 @@ func reconstructRequestBody(ctx context.Context, q *generated.Queries, requestID *result = parentBody[:prefixLen] + suffix return nil } + +// GetModels returns all models from the database +func (db *DB) GetModels(ctx context.Context) ([]generated.Model, error) { + var models []generated.Model + err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error { + q := generated.New(rx.Conn()) + var err error + models, err = q.GetModels(ctx) + return err + }) + return models, err +} + +// GetModel returns a model by ID +func (db *DB) GetModel(ctx context.Context, modelID string) (*generated.Model, error) { + var model generated.Model + err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error { + q := generated.New(rx.Conn()) + var err error + model, err = q.GetModel(ctx, modelID) + return err + }) + if err != nil { + return nil, err + } + return &model, nil +} + +// CreateModel creates a new model +func (db *DB) CreateModel(ctx context.Context, params generated.CreateModelParams) (*generated.Model, error) { + var model generated.Model + err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error { + q := generated.New(tx.Conn()) + var err error + model, err = q.CreateModel(ctx, params) + return err + }) + if err != nil { + return nil, err + } + return &model, nil +} + +// UpdateModel updates a model +func (db *DB) UpdateModel(ctx context.Context, params generated.UpdateModelParams) (*generated.Model, error) { + var model generated.Model + err := db.pool.Tx(ctx, func(ctx 
context.Context, tx *Tx) error { + q := generated.New(tx.Conn()) + var err error + model, err = q.UpdateModel(ctx, params) + return err + }) + if err != nil { + return nil, err + } + return &model, nil +} + +// DeleteModel deletes a model +func (db *DB) DeleteModel(ctx context.Context, modelID string) error { + return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error { + q := generated.New(tx.Conn()) + return q.DeleteModel(ctx, modelID) + }) +} diff --git a/db/generated/llm_requests.sql.go b/db/generated/llm_requests.sql.go index 599d6023040b91d281b84f32b1d87048be6e1ac5..b6ccb368f571721fc9930fec8f3efba9da333e1a 100644 --- a/db/generated/llm_requests.sql.go +++ b/db/generated/llm_requests.sql.go @@ -150,22 +150,24 @@ func (q *Queries) InsertLLMRequest(ctx context.Context, arg InsertLLMRequestPara } const listRecentLLMRequests = `-- name: ListRecentLLMRequests :many -SELECT - id, - conversation_id, - model, - provider, - url, - LENGTH(request_body) as request_body_length, - LENGTH(response_body) as response_body_length, - status_code, - error, - duration_ms, - created_at, - prefix_request_id, - prefix_length -FROM llm_requests -ORDER BY id DESC +SELECT + r.id, + r.conversation_id, + r.model, + m.display_name as model_display_name, + r.provider, + r.url, + LENGTH(r.request_body) as request_body_length, + LENGTH(r.response_body) as response_body_length, + r.status_code, + r.error, + r.duration_ms, + r.created_at, + r.prefix_request_id, + r.prefix_length +FROM llm_requests r +LEFT JOIN models m ON r.model = m.model_id +ORDER BY r.id DESC LIMIT ? ` @@ -173,6 +175,7 @@ type ListRecentLLMRequestsRow struct { ID int64 `json:"id"` ConversationID *string `json:"conversation_id"` Model string `json:"model"` + ModelDisplayName *string `json:"model_display_name"` Provider string `json:"provider"` Url string `json:"url"` RequestBodyLength *int64 `json:"request_body_length"` @@ -198,6 +201,7 @@ func (q *Queries) ListRecentLLMRequests(ctx context.Context, limit int64) ([]Lis &i.ID, &i.ConversationID, &i.Model, + &i.ModelDisplayName, &i.Provider, &i.Url, &i.RequestBodyLength, diff --git a/db/generated/models.go b/db/generated/models.go index 4bd0ad9a39e9e498850b1886f7aeb214291b8c49..4fcfda92a6481fb51e28ddd27f9cf5b6c669242f 100644 --- a/db/generated/models.go +++ b/db/generated/models.go @@ -52,3 +52,16 @@ type Migration struct { MigrationName string `json:"migration_name"` ExecutedAt *time.Time `json:"executed_at"` } + +type Model struct { + ModelID string `json:"model_id"` + DisplayName string `json:"display_name"` + ProviderType string `json:"provider_type"` + Endpoint string `json:"endpoint"` + ApiKey string `json:"api_key"` + ModelName string `json:"model_name"` + MaxTokens int64 `json:"max_tokens"` + Tags string `json:"tags"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/db/generated/models.sql.go b/db/generated/models.sql.go new file mode 100644 index 0000000000000000000000000000000000000000..2abe2139f8e27c9c386a7c6351908063cb96b8f1 --- /dev/null +++ b/db/generated/models.sql.go @@ -0,0 +1,175 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: models.sql + +package generated + +import ( + "context" +) + +const createModel = `-- name: CreateModel :one +INSERT INTO models (model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags) +VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
+RETURNING model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags, created_at, updated_at +` + +type CreateModelParams struct { + ModelID string `json:"model_id"` + DisplayName string `json:"display_name"` + ProviderType string `json:"provider_type"` + Endpoint string `json:"endpoint"` + ApiKey string `json:"api_key"` + ModelName string `json:"model_name"` + MaxTokens int64 `json:"max_tokens"` + Tags string `json:"tags"` +} + +func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model, error) { + row := q.db.QueryRowContext(ctx, createModel, + arg.ModelID, + arg.DisplayName, + arg.ProviderType, + arg.Endpoint, + arg.ApiKey, + arg.ModelName, + arg.MaxTokens, + arg.Tags, + ) + var i Model + err := row.Scan( + &i.ModelID, + &i.DisplayName, + &i.ProviderType, + &i.Endpoint, + &i.ApiKey, + &i.ModelName, + &i.MaxTokens, + &i.Tags, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteModel = `-- name: DeleteModel :exec +DELETE FROM models WHERE model_id = ? +` + +func (q *Queries) DeleteModel(ctx context.Context, modelID string) error { + _, err := q.db.ExecContext(ctx, deleteModel, modelID) + return err +} + +const getModel = `-- name: GetModel :one +SELECT model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags, created_at, updated_at FROM models WHERE model_id = ? +` + +func (q *Queries) GetModel(ctx context.Context, modelID string) (Model, error) { + row := q.db.QueryRowContext(ctx, getModel, modelID) + var i Model + err := row.Scan( + &i.ModelID, + &i.DisplayName, + &i.ProviderType, + &i.Endpoint, + &i.ApiKey, + &i.ModelName, + &i.MaxTokens, + &i.Tags, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getModels = `-- name: GetModels :many +SELECT model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags, created_at, updated_at FROM models ORDER BY created_at ASC +` + +func (q *Queries) GetModels(ctx context.Context) ([]Model, error) { + rows, err := q.db.QueryContext(ctx, getModels) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Model{} + for rows.Next() { + var i Model + if err := rows.Scan( + &i.ModelID, + &i.DisplayName, + &i.ProviderType, + &i.Endpoint, + &i.ApiKey, + &i.ModelName, + &i.MaxTokens, + &i.Tags, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateModel = `-- name: UpdateModel :one +UPDATE models +SET display_name = ?, + provider_type = ?, + endpoint = ?, + api_key = ?, + model_name = ?, + max_tokens = ?, + tags = ?, + updated_at = CURRENT_TIMESTAMP +WHERE model_id = ? 
+RETURNING model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags, created_at, updated_at +` + +type UpdateModelParams struct { + DisplayName string `json:"display_name"` + ProviderType string `json:"provider_type"` + Endpoint string `json:"endpoint"` + ApiKey string `json:"api_key"` + ModelName string `json:"model_name"` + MaxTokens int64 `json:"max_tokens"` + Tags string `json:"tags"` + ModelID string `json:"model_id"` +} + +func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model, error) { + row := q.db.QueryRowContext(ctx, updateModel, + arg.DisplayName, + arg.ProviderType, + arg.Endpoint, + arg.ApiKey, + arg.ModelName, + arg.MaxTokens, + arg.Tags, + arg.ModelID, + ) + var i Model + err := row.Scan( + &i.ModelID, + &i.DisplayName, + &i.ProviderType, + &i.Endpoint, + &i.ApiKey, + &i.ModelName, + &i.MaxTokens, + &i.Tags, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/db/query/llm_requests.sql b/db/query/llm_requests.sql index 3f1b8c10db355687a4f1859e5522bf9fbc59655f..fa86bd8a447d0f4ba45fdd0600d41f6ad0a40174 100644 --- a/db/query/llm_requests.sql +++ b/db/query/llm_requests.sql @@ -24,22 +24,24 @@ LIMIT 1; SELECT * FROM llm_requests WHERE id = ?; -- name: ListRecentLLMRequests :many -SELECT - id, - conversation_id, - model, - provider, - url, - LENGTH(request_body) as request_body_length, - LENGTH(response_body) as response_body_length, - status_code, - error, - duration_ms, - created_at, - prefix_request_id, - prefix_length -FROM llm_requests -ORDER BY id DESC +SELECT + r.id, + r.conversation_id, + r.model, + m.display_name as model_display_name, + r.provider, + r.url, + LENGTH(r.request_body) as request_body_length, + LENGTH(r.response_body) as response_body_length, + r.status_code, + r.error, + r.duration_ms, + r.created_at, + r.prefix_request_id, + r.prefix_length +FROM llm_requests r +LEFT JOIN models m ON r.model = m.model_id +ORDER BY r.id DESC LIMIT ?; -- name: GetLLMRequestBody :one diff --git a/db/query/models.sql b/db/query/models.sql new file mode 100644 index 0000000000000000000000000000000000000000..1815da9b7ca6fc4eee337caefbb2c8a020eaf197 --- /dev/null +++ b/db/query/models.sql @@ -0,0 +1,26 @@ +-- name: GetModels :many +SELECT * FROM models ORDER BY created_at ASC; + +-- name: GetModel :one +SELECT * FROM models WHERE model_id = ?; + +-- name: CreateModel :one +INSERT INTO models (model_id, display_name, provider_type, endpoint, api_key, model_name, max_tokens, tags) +VALUES (?, ?, ?, ?, ?, ?, ?, ?) +RETURNING *; + +-- name: UpdateModel :one +UPDATE models +SET display_name = ?, + provider_type = ?, + endpoint = ?, + api_key = ?, + model_name = ?, + max_tokens = ?, + tags = ?, + updated_at = CURRENT_TIMESTAMP +WHERE model_id = ? 
+RETURNING *; + +-- name: DeleteModel :exec +DELETE FROM models WHERE model_id = ?; diff --git a/db/schema/012-custom-models.sql b/db/schema/012-custom-models.sql new file mode 100644 index 0000000000000000000000000000000000000000..358a0e458ba281440e5683abe7508ba6b8566582 --- /dev/null +++ b/db/schema/012-custom-models.sql @@ -0,0 +1,15 @@ +-- Models table +-- Stores user-configured LLM models with API keys + +CREATE TABLE models ( + model_id TEXT PRIMARY KEY, + display_name TEXT NOT NULL, + provider_type TEXT NOT NULL CHECK (provider_type IN ('anthropic', 'openai', 'openai-responses', 'gemini')), + endpoint TEXT NOT NULL, + api_key TEXT NOT NULL, + model_name TEXT NOT NULL, -- The actual model name sent to the API (e.g., "claude-sonnet-4-5-20250514") + max_tokens INTEGER NOT NULL DEFAULT 200000, + tags TEXT NOT NULL DEFAULT '', -- Comma-separated tags (e.g., "slug" for slug generation) + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); diff --git a/llm/oai/oai.go b/llm/oai/oai.go index fa0504d2f16b7e6a18ac2426cf1769ebcb75c442..4349630d766a25e6b5c9aabac6b7c4b39162917b 100644 --- a/llm/oai/oai.go +++ b/llm/oai/oai.go @@ -292,6 +292,13 @@ var ( APIKeyEnv: OpenAIAPIKeyEnv, } + GPT52Codex = Model{ + UserName: "gpt-5.2-codex", + ModelName: "gpt-5.2-codex", + URL: OpenAIURL, + APIKeyEnv: OpenAIAPIKeyEnv, + } + // Skaband-specific model names. // Provider details (URL and APIKeyEnv) are handled by skaband Qwen = Model{ @@ -329,6 +336,7 @@ var ModelsRegistry = []Model{ GPT5Mini, GPT5Nano, GPT5Codex, + GPT52Codex, O3, O4Mini, Gemini25Flash, @@ -525,22 +533,6 @@ func fromLLMMessage(msg llm.Message) []openai.ChatCompletionMessage { return messages } -// requiresMaxCompletionTokens returns true if the model requires max_completion_tokens instead of max_tokens. -func (m Model) requiresMaxCompletionTokens() bool { - // Reasoning models always use max_completion_tokens - if m.IsReasoningModel { - return true - } - - // GPT-5 series models also require max_completion_tokens - switch m.ModelName { - case "gpt-5.1", "gpt-5.1-mini", "gpt-5.1-nano": - return true - default: - return false - } -} - // fromLLMToolChoice converts llm.ToolChoice to the format expected by OpenAI. 
func fromLLMToolChoice(tc *llm.ToolChoice) any { if tc == nil { @@ -812,15 +804,11 @@ func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error // Create the OpenAI request req := openai.ChatCompletionRequest{ - Model: model.ModelName, - Messages: allMessages, - Tools: tools, - ToolChoice: fromLLMToolChoice(ir.ToolChoice), // TODO: make fromLLMToolChoice return an error when a perfect translation is not possible - } - if model.requiresMaxCompletionTokens() { - req.MaxCompletionTokens = cmp.Or(s.MaxTokens, DefaultMaxTokens) - } else { - req.MaxTokens = cmp.Or(s.MaxTokens, DefaultMaxTokens) + Model: model.ModelName, + Messages: allMessages, + Tools: tools, + ToolChoice: fromLLMToolChoice(ir.ToolChoice), // TODO: make fromLLMToolChoice return an error when a perfect translation is not possible + MaxCompletionTokens: cmp.Or(s.MaxTokens, DefaultMaxTokens), } // Construct the full URL for logging and debugging fullURL := baseURL + "/chat/completions" diff --git a/llm/oai/oai_responses.go b/llm/oai/oai_responses.go index 3fedd053f0fd2e3ae02b1a3002bcae2b8b89c6e3..9b58ac9449761191785c6b309b42322a6075b97d 100644 --- a/llm/oai/oai_responses.go +++ b/llm/oai/oai_responses.go @@ -340,6 +340,8 @@ func (s *ResponsesService) TokenContextWindow() int { // Use the same context window logic as the regular service switch model.ModelName { + case "gpt-5.2-codex": + return 272000 // 272k for gpt-5.2-codex case "gpt-5.1-codex": return 256000 // 256k for gpt-5.1-codex case "gpt-4.1-2025-04-14", "gpt-4.1-mini-2025-04-14", "gpt-4.1-nano-2025-04-14": diff --git a/llm/oai/oai_responses_test.go b/llm/oai/oai_responses_test.go index 8e696cee30f953ccf9d34060f96b9028cae59795..d47492b11410004c9da94c0a6c1e62ce001e6693 100644 --- a/llm/oai/oai_responses_test.go +++ b/llm/oai/oai_responses_test.go @@ -284,6 +284,7 @@ func TestResponsesServiceTokenContextWindow(t *testing.T) { model Model expected int }{ + {model: GPT52Codex, expected: 272000}, {model: GPT5Codex, expected: 256000}, {model: GPT41, expected: 200000}, {model: GPT4o, expected: 128000}, diff --git a/llm/oai/oai_test.go b/llm/oai/oai_test.go index e631e061b9a600856862ca57433dd24126275f84..76e0648be6d69314aea4e6e9d5b34b76473f2a1a 100644 --- a/llm/oai/oai_test.go +++ b/llm/oai/oai_test.go @@ -12,106 +12,6 @@ import ( "shelley.exe.dev/llm" ) -func TestRequiresMaxCompletionTokens(t *testing.T) { - tests := []struct { - name string - model Model - expected bool - }{ - { - name: "GPT-5 requires max_completion_tokens", - model: GPT5, - expected: true, - }, - { - name: "GPT-5 Mini requires max_completion_tokens", - model: GPT5Mini, - expected: true, - }, - { - name: "O3 reasoning model requires max_completion_tokens", - model: O3, - expected: true, - }, - { - name: "O4-mini reasoning model requires max_completion_tokens", - model: O4Mini, - expected: true, - }, - { - name: "GPT-4.1 uses max_tokens", - model: GPT41, - expected: false, - }, - { - name: "GPT-4o uses max_tokens", - model: GPT4o, - expected: false, - }, - { - name: "GPT-4o Mini uses max_tokens", - model: GPT4oMini, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := tt.model.requiresMaxCompletionTokens() - if result != tt.expected { - t.Errorf("requiresMaxCompletionTokens() = %v, expected %v", result, tt.expected) - } - }) - } -} - -func TestRequestParameterGeneration(t *testing.T) { - // Test that we can generate the correct request structure without making API calls - tests := []struct { - name string - model Model - 
expectMaxTokens bool - expectMaxCompletionTokens bool - }{ - { - name: "GPT-5 uses max_completion_tokens", - model: GPT5, - expectMaxTokens: false, - expectMaxCompletionTokens: true, - }, - { - name: "GPT-5 Mini uses max_completion_tokens", - model: GPT5Mini, - expectMaxTokens: false, - expectMaxCompletionTokens: true, - }, - { - name: "GPT-4.1 uses max_tokens", - model: GPT41, - expectMaxTokens: true, - expectMaxCompletionTokens: false, - }, - { - name: "O3 uses max_completion_tokens", - model: O3, - expectMaxTokens: false, - expectMaxCompletionTokens: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - usesMaxCompletionTokens := tt.model.requiresMaxCompletionTokens() - if tt.expectMaxCompletionTokens && !usesMaxCompletionTokens { - t.Errorf("Expected model %s to use max_completion_tokens, but it doesn't", tt.model.UserName) - } - if tt.expectMaxTokens && usesMaxCompletionTokens { - t.Errorf("Expected model %s to use max_tokens, but it uses max_completion_tokens", tt.model.UserName) - } - }) - } -} - func TestToRoleFromString(t *testing.T) { tests := []struct { name string diff --git a/models/models.go b/models/models.go index 9bf92e18e81b20af00432f5b8bb1acfcb2d86dbc..503cbab9475af9ae407f4ef1e5556f2a03257d95 100644 --- a/models/models.go +++ b/models/models.go @@ -148,47 +148,15 @@ func All() []Model { }, }, { - ID: "gpt-5", + ID: "gpt-5.2-codex", Provider: ProviderOpenAI, - Description: "GPT-5", + Description: "GPT-5.2 Codex", RequiredEnvVars: []string{"OPENAI_API_KEY"}, Factory: func(config *Config, httpc *http.Client) (llm.Service, error) { if config.OpenAIAPIKey == "" { - return nil, fmt.Errorf("gpt-5 requires OPENAI_API_KEY") + return nil, fmt.Errorf("gpt-5.2-codex requires OPENAI_API_KEY") } - svc := &oai.Service{Model: oai.GPT5, APIKey: config.OpenAIAPIKey, HTTPC: httpc} - if url := config.getOpenAIURL(); url != "" { - svc.ModelURL = url - } - return svc, nil - }, - }, - { - ID: "gpt-5-nano", - Provider: ProviderOpenAI, - Description: "GPT-5 Nano", - RequiredEnvVars: []string{"OPENAI_API_KEY"}, - Factory: func(config *Config, httpc *http.Client) (llm.Service, error) { - if config.OpenAIAPIKey == "" { - return nil, fmt.Errorf("gpt-5-nano requires OPENAI_API_KEY") - } - svc := &oai.Service{Model: oai.GPT5Nano, APIKey: config.OpenAIAPIKey, HTTPC: httpc} - if url := config.getOpenAIURL(); url != "" { - svc.ModelURL = url - } - return svc, nil - }, - }, - { - ID: "gpt-5.1-codex", - Provider: ProviderOpenAI, - Description: "GPT-5.1 Codex (uses Responses API)", - RequiredEnvVars: []string{"OPENAI_API_KEY"}, - Factory: func(config *Config, httpc *http.Client) (llm.Service, error) { - if config.OpenAIAPIKey == "" { - return nil, fmt.Errorf("gpt-5.1-codex requires OPENAI_API_KEY") - } - svc := &oai.ResponsesService{Model: oai.GPT5Codex, APIKey: config.OpenAIAPIKey, HTTPC: httpc} + svc := &oai.ResponsesService{Model: oai.GPT52Codex, APIKey: config.OpenAIAPIKey, HTTPC: httpc} if url := config.getOpenAIURL(); url != "" { svc.ModelURL = url } @@ -259,38 +227,6 @@ func All() []Model { return svc, nil }, }, - { - ID: "gemini-2.5-pro", - Provider: ProviderGemini, - Description: "Gemini 2.5 Pro", - RequiredEnvVars: []string{"GEMINI_API_KEY"}, - Factory: func(config *Config, httpc *http.Client) (llm.Service, error) { - if config.GeminiAPIKey == "" { - return nil, fmt.Errorf("gemini-2.5-pro requires GEMINI_API_KEY") - } - svc := &gem.Service{APIKey: config.GeminiAPIKey, Model: "gemini-2.5-pro", HTTPC: httpc} - if url := config.getGeminiURL(); url != "" { - svc.URL = 
url - } - return svc, nil - }, - }, - { - ID: "gemini-2.5-flash", - Provider: ProviderGemini, - Description: "Gemini 2.5 Flash", - RequiredEnvVars: []string{"GEMINI_API_KEY"}, - Factory: func(config *Config, httpc *http.Client) (llm.Service, error) { - if config.GeminiAPIKey == "" { - return nil, fmt.Errorf("gemini-2.5-flash requires GEMINI_API_KEY") - } - svc := &gem.Service{APIKey: config.GeminiAPIKey, Model: "gemini-2.5-flash", HTTPC: httpc} - if url := config.getGeminiURL(); url != "" { - svc.URL = url - } - return svc, nil - }, - }, { ID: "predictable", Provider: ProviderBuiltIn, @@ -332,7 +268,8 @@ func Default() Model { type Manager struct { services map[string]serviceEntry logger *slog.Logger - db *db.DB + db *db.DB // for custom models and LLM request recording + httpc *http.Client // HTTP client with recording middleware } type serviceEntry struct { @@ -502,6 +439,9 @@ func NewManager(cfg *Config) (*Manager, error) { httpc = llmhttp.NewClient(nil, nil) } + // Store the HTTP client for use with custom models + manager.httpc = httpc + for _, model := range All() { svc, err := model.Factory(cfg, httpc) if err != nil { @@ -520,6 +460,34 @@ func NewManager(cfg *Config) (*Manager, error) { // GetService returns the LLM service for the given model ID, wrapped with logging func (m *Manager) GetService(modelID string) (llm.Service, error) { + // Check custom models first if we have a database + if m.db != nil { + dbModels, err := m.db.GetModels(context.Background()) + if err == nil && len(dbModels) > 0 { + // Custom models exist - only serve custom models, not built-in ones + for _, model := range dbModels { + if model.ModelID == modelID { + svc := m.createServiceFromModel(&model) + if svc != nil { + if m.logger != nil { + return &loggingService{ + service: svc, + logger: m.logger, + modelID: modelID, + provider: Provider(model.ProviderType), + db: m.db, + }, nil + } + return svc, nil + } + } + } + // Custom models exist but this model ID wasn't found among them + return nil, fmt.Errorf("unsupported model: %s", modelID) + } + } + + // No custom models - fall back to built-in models if entry, ok := m.services[modelID]; ok { // Wrap with logging if we have a logger if m.logger != nil { @@ -538,9 +506,20 @@ func (m *Manager) GetService(modelID string) (llm.Service, error) { // GetAvailableModels returns a list of available model IDs in the same order as All() func (m *Manager) GetAvailableModels() []string { - // Return IDs in the same order as All() for consistency - all := All() var ids []string + + // If we have custom models in the database, use ONLY those + if m.db != nil { + if dbModels, err := m.db.GetModels(context.Background()); err == nil && len(dbModels) > 0 { + for _, model := range dbModels { + ids = append(ids, model.ModelID) + } + return ids + } + } + + // No custom models - fall back to built-in models in the same order as All() + all := All() for _, model := range all { if _, ok := m.services[model.ID]; ok { ids = append(ids, model.ID) @@ -551,6 +530,80 @@ func (m *Manager) GetAvailableModels() []string { // HasModel reports whether the manager has a service for the given model ID func (m *Manager) HasModel(modelID string) bool { + // Check custom models first + if m.db != nil { + if model, err := m.db.GetModel(context.Background(), modelID); err == nil && model != nil { + return true + } + } _, ok := m.services[modelID] return ok } + +// ModelInfo contains display name and tags for a model +type ModelInfo struct { + DisplayName string + Tags string +} + +// GetModelInfo 
returns the display name and tags for a model +func (m *Manager) GetModelInfo(modelID string) *ModelInfo { + if m.db == nil { + return nil + } + model, err := m.db.GetModel(context.Background(), modelID) + if err != nil { + return nil + } + return &ModelInfo{ + DisplayName: model.DisplayName, + Tags: model.Tags, + } +} + +// createServiceFromModel creates an LLM service from a database model configuration +func (m *Manager) createServiceFromModel(model *generated.Model) llm.Service { + switch model.ProviderType { + case "anthropic": + return &ant.Service{ + APIKey: model.ApiKey, + URL: model.Endpoint, + Model: model.ModelName, + HTTPC: m.httpc, + } + case "openai": + return &oai.Service{ + APIKey: model.ApiKey, + ModelURL: model.Endpoint, + Model: oai.Model{ + ModelName: model.ModelName, + URL: model.Endpoint, + }, + MaxTokens: int(model.MaxTokens), + HTTPC: m.httpc, + } + case "openai-responses": + return &oai.ResponsesService{ + APIKey: model.ApiKey, + ModelURL: model.Endpoint, + Model: oai.Model{ + ModelName: model.ModelName, + URL: model.Endpoint, + }, + MaxTokens: int(model.MaxTokens), + HTTPC: m.httpc, + } + case "gemini": + return &gem.Service{ + APIKey: model.ApiKey, + URL: model.Endpoint, + Model: model.ModelName, + HTTPC: m.httpc, + } + default: + if m.logger != nil { + m.logger.Error("Unknown provider type for model", "model_id", model.ModelID, "provider_type", model.ProviderType) + } + return nil + } +} diff --git a/models/models_test.go b/models/models_test.go index fd0c792c69107ecd6536fc6edabcd3867fe3664a..5baf3b2f5ccea33444d7d62f2a1dc4ed802b2831 100644 --- a/models/models_test.go +++ b/models/models_test.go @@ -36,7 +36,7 @@ func TestByID(t *testing.T) { wantNil bool }{ {id: "qwen3-coder-fireworks", wantID: "qwen3-coder-fireworks", wantNil: false}, - {id: "gpt-5", wantID: "gpt-5", wantNil: false}, + {id: "gpt-5.2-codex", wantID: "gpt-5.2-codex", wantNil: false}, {id: "claude-sonnet-4.5", wantID: "claude-sonnet-4.5", wantNil: false}, {id: "claude-haiku-4.5", wantID: "claude-haiku-4.5", wantNil: false}, {id: "claude-opus-4.5", wantID: "claude-opus-4.5", wantNil: false}, diff --git a/server/cancel_claude_test.go b/server/cancel_claude_test.go index c6d9344388c0fdd0679cc16f5ef1115508e68e7c..4736e10d8fd72190c5c06882b34ef42599c6a2f2 100644 --- a/server/cancel_claude_test.go +++ b/server/cancel_claude_test.go @@ -19,6 +19,7 @@ import ( "shelley.exe.dev/db/generated" "shelley.exe.dev/llm" "shelley.exe.dev/llm/ant" + "shelley.exe.dev/models" ) // ClaudeTestHarness extends TestHarness with Claude-specific functionality @@ -544,6 +545,13 @@ func (m *claudeLLMManager) HasModel(modelID string) bool { return modelID == "claude" || modelID == "claude-haiku-4.5" } +func (m *claudeLLMManager) GetModelInfo(modelID string) *models.ModelInfo { + if modelID == "claude-haiku-4.5" { + return &models.ModelInfo{DisplayName: "Claude Haiku", Tags: "slug"} + } + return nil +} + // TestClaudeCancelDuringToolCall tests cancellation during tool execution with Claude func TestClaudeCancelDuringToolCall(t *testing.T) { h := NewClaudeTestHarness(t) diff --git a/server/cancel_test.go b/server/cancel_test.go index 73fda6b90c05e8d3ef13c98b565d2804a2e91603..6917229091fe93cfb1e4a428f3ce69f710e5de5a 100644 --- a/server/cancel_test.go +++ b/server/cancel_test.go @@ -15,6 +15,7 @@ import ( "shelley.exe.dev/db/generated" "shelley.exe.dev/llm" "shelley.exe.dev/loop" + "shelley.exe.dev/models" ) // setupTestDB creates a test database @@ -374,3 +375,7 @@ func (m *testLLMManager) GetAvailableModels() []string { func (m 
*testLLMManager) HasModel(modelID string) bool { return modelID == "predictable" } + +func (m *testLLMManager) GetModelInfo(modelID string) *models.ModelInfo { + return nil +} diff --git a/server/custom_models.go b/server/custom_models.go new file mode 100644 index 0000000000000000000000000000000000000000..03116c6b9fbf2226462a7ce36949ba57cfe93c50 --- /dev/null +++ b/server/custom_models.go @@ -0,0 +1,404 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/google/uuid" + "shelley.exe.dev/db/generated" + "shelley.exe.dev/llm" + "shelley.exe.dev/llm/ant" + "shelley.exe.dev/llm/gem" + "shelley.exe.dev/llm/oai" +) + +// ModelAPI is the API representation of a model +type ModelAPI struct { + ModelID string `json:"model_id"` + DisplayName string `json:"display_name"` + ProviderType string `json:"provider_type"` + Endpoint string `json:"endpoint"` + APIKey string `json:"api_key"` + ModelName string `json:"model_name"` + MaxTokens int64 `json:"max_tokens"` + Tags string `json:"tags"` // Comma-separated tags (e.g., "slug" for slug generation) +} + +// CreateModelRequest is the request body for creating a model +type CreateModelRequest struct { + DisplayName string `json:"display_name"` + ProviderType string `json:"provider_type"` + Endpoint string `json:"endpoint"` + APIKey string `json:"api_key"` + ModelName string `json:"model_name"` + MaxTokens int64 `json:"max_tokens"` + Tags string `json:"tags"` // Comma-separated tags +} + +// UpdateModelRequest is the request body for updating a model +type UpdateModelRequest struct { + DisplayName string `json:"display_name"` + ProviderType string `json:"provider_type"` + Endpoint string `json:"endpoint"` + APIKey string `json:"api_key"` // Empty string means keep existing + ModelName string `json:"model_name"` + MaxTokens int64 `json:"max_tokens"` + Tags string `json:"tags"` // Comma-separated tags +} + +// TestModelRequest is the request body for testing a model +type TestModelRequest struct { + ModelID string `json:"model_id,omitempty"` // If provided, use stored API key + ProviderType string `json:"provider_type"` + Endpoint string `json:"endpoint"` + APIKey string `json:"api_key"` + ModelName string `json:"model_name"` +} + +func toModelAPI(m generated.Model) ModelAPI { + return ModelAPI{ + ModelID: m.ModelID, + DisplayName: m.DisplayName, + ProviderType: m.ProviderType, + Endpoint: m.Endpoint, + APIKey: m.ApiKey, + ModelName: m.ModelName, + MaxTokens: m.MaxTokens, + Tags: m.Tags, + } +} + +func (s *Server) handleCustomModels(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + s.handleListModels(w, r) + case http.MethodPost: + s.handleCreateModel(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) { + models, err := s.db.GetModels(r.Context()) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to get models: %v", err), http.StatusInternalServerError) + return + } + + apiModels := make([]ModelAPI, len(models)) + for i, m := range models { + apiModels[i] = toModelAPI(m) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(apiModels) +} + +func (s *Server) handleCreateModel(w http.ResponseWriter, r *http.Request) { + var req CreateModelRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + 
// Validate required fields + if req.DisplayName == "" || req.ProviderType == "" || req.Endpoint == "" || req.APIKey == "" || req.ModelName == "" { + http.Error(w, "display_name, provider_type, endpoint, api_key, and model_name are required", http.StatusBadRequest) + return + } + + // Validate provider type + if req.ProviderType != "anthropic" && req.ProviderType != "openai" && req.ProviderType != "openai-responses" && req.ProviderType != "gemini" { + http.Error(w, "provider_type must be 'anthropic', 'openai', 'openai-responses', or 'gemini'", http.StatusBadRequest) + return + } + + // Generate model ID + modelID := "custom-" + uuid.New().String()[:8] + + // Default max tokens + if req.MaxTokens <= 0 { + req.MaxTokens = 200000 + } + + model, err := s.db.CreateModel(r.Context(), generated.CreateModelParams{ + ModelID: modelID, + DisplayName: req.DisplayName, + ProviderType: req.ProviderType, + Endpoint: req.Endpoint, + ApiKey: req.APIKey, + ModelName: req.ModelName, + MaxTokens: req.MaxTokens, + Tags: req.Tags, + }) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to create model: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(toModelAPI(*model)) +} + +func (s *Server) handleCustomModel(w http.ResponseWriter, r *http.Request) { + // Extract model ID from URL path: /api/custom-models/{id} or /api/custom-models/{id}/duplicate + path := strings.TrimPrefix(r.URL.Path, "/api/custom-models/") + if path == "" { + http.Error(w, "Invalid model ID", http.StatusBadRequest) + return + } + + // Check for /duplicate suffix + if strings.HasSuffix(path, "/duplicate") { + modelID := strings.TrimSuffix(path, "/duplicate") + if r.Method == http.MethodPost { + s.handleDuplicateModel(w, r, modelID) + } else { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + return + } + + if strings.Contains(path, "/") { + http.Error(w, "Invalid model ID", http.StatusBadRequest) + return + } + modelID := path + + switch r.Method { + case http.MethodGet: + s.handleGetModel(w, r, modelID) + case http.MethodPut: + s.handleUpdateModel(w, r, modelID) + case http.MethodDelete: + s.handleDeleteModel(w, r, modelID) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleGetModel(w http.ResponseWriter, r *http.Request, modelID string) { + model, err := s.db.GetModel(r.Context(), modelID) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to get model: %v", err), http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(toModelAPI(*model)) +} + +func (s *Server) handleUpdateModel(w http.ResponseWriter, r *http.Request, modelID string) { + // First, get the existing model to get the current API key if not provided + existing, err := s.db.GetModel(r.Context(), modelID) + if err != nil { + http.Error(w, fmt.Sprintf("Model not found: %v", err), http.StatusNotFound) + return + } + + var req UpdateModelRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + // Use existing API key if not provided + apiKey := req.APIKey + if apiKey == "" { + apiKey = existing.ApiKey + } + + // Default max tokens + if req.MaxTokens <= 0 { + req.MaxTokens = 200000 + } + + model, err := s.db.UpdateModel(r.Context(), generated.UpdateModelParams{ + DisplayName: req.DisplayName, + 
ProviderType: req.ProviderType, + Endpoint: req.Endpoint, + ApiKey: apiKey, + ModelName: req.ModelName, + MaxTokens: req.MaxTokens, + Tags: req.Tags, + ModelID: modelID, + }) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to update model: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(toModelAPI(*model)) +} + +func (s *Server) handleDeleteModel(w http.ResponseWriter, r *http.Request, modelID string) { + err := s.db.DeleteModel(r.Context(), modelID) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to delete model: %v", err), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// DuplicateModelRequest allows overriding fields when duplicating +type DuplicateModelRequest struct { + DisplayName string `json:"display_name,omitempty"` +} + +func (s *Server) handleDuplicateModel(w http.ResponseWriter, r *http.Request, modelID string) { + // Get the source model (including API key) + source, err := s.db.GetModel(r.Context(), modelID) + if err != nil { + http.Error(w, fmt.Sprintf("Source model not found: %v", err), http.StatusNotFound) + return + } + + // Parse optional overrides + var req DuplicateModelRequest + if r.Body != nil { + json.NewDecoder(r.Body).Decode(&req) // Ignore errors - all fields optional + } + + // Generate new model ID + newModelID := "custom-" + uuid.New().String()[:8] + + // Use provided display name or generate one + displayName := req.DisplayName + if displayName == "" { + displayName = source.DisplayName + " (copy)" + } + + // Create the duplicate with the same API key + model, err := s.db.CreateModel(r.Context(), generated.CreateModelParams{ + ModelID: newModelID, + DisplayName: displayName, + ProviderType: source.ProviderType, + Endpoint: source.Endpoint, + ApiKey: source.ApiKey, // Copy the API key! 
+ ModelName: source.ModelName, + MaxTokens: source.MaxTokens, + Tags: "", // Don't copy tags + }) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to duplicate model: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(toModelAPI(*model)) +} + +func (s *Server) handleTestModel(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req TestModelRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + // If model_id is provided and api_key is empty, look up the stored key + if req.ModelID != "" && req.APIKey == "" { + model, err := s.db.GetModel(r.Context(), req.ModelID) + if err != nil { + http.Error(w, fmt.Sprintf("Model not found: %v", err), http.StatusNotFound) + return + } + req.APIKey = model.ApiKey + } + + if req.ProviderType == "" || req.Endpoint == "" || req.APIKey == "" || req.ModelName == "" { + http.Error(w, "provider_type, endpoint, api_key, and model_name are required", http.StatusBadRequest) + return + } + + // Create the appropriate service based on provider type + var service llm.Service + switch req.ProviderType { + case "anthropic": + service = &ant.Service{ + APIKey: req.APIKey, + URL: req.Endpoint, + Model: req.ModelName, + } + case "openai": + service = &oai.Service{ + APIKey: req.APIKey, + ModelURL: req.Endpoint, + Model: oai.Model{ + ModelName: req.ModelName, + URL: req.Endpoint, + }, + } + case "gemini": + service = &gem.Service{ + APIKey: req.APIKey, + URL: req.Endpoint, + Model: req.ModelName, + } + case "openai-responses": + service = &oai.ResponsesService{ + APIKey: req.APIKey, + Model: oai.Model{ + ModelName: req.ModelName, + URL: req.Endpoint, + }, + } + default: + http.Error(w, "Invalid provider_type", http.StatusBadRequest) + return + } + + // Send a simple test request + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + request := &llm.Request{ + Messages: []llm.Message{ + { + Role: llm.MessageRoleUser, + Content: []llm.Content{ + {Type: llm.ContentTypeText, Text: "Say 'test successful' in exactly two words."}, + }, + }, + }, + } + + response, err := service.Do(ctx, request) + if err != nil { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": false, + "message": fmt.Sprintf("Test failed: %v", err), + }) + return + } + + // Check if we got a response + if len(response.Content) == 0 || response.Content[0].Text == "" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": false, + "message": "Test failed: empty response from model", + }) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "message": fmt.Sprintf("Test successful! 
Response: %s", response.Content[0].Text), + }) +} diff --git a/server/debug_handlers.go b/server/debug_handlers.go index 3ff6e156cf81651209d1a66e92f2d3b775b613e4..00dceea25f620bc0b273804fd281af5548d387b6 100644 --- a/server/debug_handlers.go +++ b/server/debug_handlers.go @@ -182,6 +182,8 @@ tr:hover { background: #252525; } } .tab-content { display: none; } .tab-content.active { display: block; } +.model-display { color: #a5d6ff; } +.model-id { color: #888; font-size: 11px; } @@ -228,6 +230,13 @@ function formatDuration(ms) { return (ms / 1000).toFixed(2) + 's'; } +function formatModel(model, displayName) { + if (displayName) { + return '' + displayName + ' (' + model + ')'; + } + return model; +} + function syntaxHighlight(json) { if (typeof json !== 'string') json = JSON.stringify(json, null, 2); json = json.replace(/&/g, '&').replace(//g, '>'); @@ -254,7 +263,7 @@ async function loadRequests() { const data = await resp.json(); renderTable(data); } catch (e) { - document.getElementById('requests-body').innerHTML = + document.getElementById('requests-body').innerHTML = 'Error loading requests: ' + e.message + ''; } } @@ -269,20 +278,20 @@ function renderTable(requests) { for (const req of requests) { const tr = document.createElement('tr'); tr.id = 'row-' + req.id; - - const statusClass = req.status_code && req.status_code >= 200 && req.status_code < 300 ? 'success' : + + const statusClass = req.status_code && req.status_code >= 200 && req.status_code < 300 ? 'success' : (req.status_code ? 'error' : ''); - + let prefixInfo = '-'; if (req.prefix_request_id) { - prefixInfo = 'prefix from #' + req.prefix_request_id + + prefixInfo = 'prefix from #' + req.prefix_request_id + ' (' + formatSize(req.prefix_length) + ')'; } - + tr.innerHTML = ` + "`" + ` ${req.id} ${formatDate(req.created_at)} - ${req.model} + ${formatModel(req.model, req.model_display_name)} ${req.provider} ${req.status_code || '-'}${req.error ? ' ⚠' : ''} ${formatDuration(req.duration_ms)} @@ -302,7 +311,7 @@ async function toggleExpand(id) { expandedRows.delete(id); return; } - + expandedRows.add(id); const row = document.getElementById('row-' + id); const expandRow = document.createElement('tr'); @@ -325,7 +334,7 @@ async function toggleExpand(id) { ` + "`" + `; row.after(expandRow); - + // Load request body loadBody(id, 'request'); } @@ -336,9 +345,9 @@ async function loadBody(id, type) { renderBody(id, type, loadedData[key]); return; } - + try { - const url = type === 'request' + const url = type === 'request' ? 
'/debug/llm_requests/' + id + '/request' : '/debug/llm_requests/' + id + '/response'; const resp = await fetch(url); @@ -363,13 +372,13 @@ async function loadBody(id, type) { function renderBody(id, type, data) { const container = document.querySelector('#tab-' + type + '-' + id + ' pre'); if (!container) return; - + if (data === null) { container.className = ''; container.textContent = '(empty)'; return; } - + container.className = ''; if (typeof data === 'object') { container.innerHTML = syntaxHighlight(JSON.stringify(data, null, 2)); @@ -382,14 +391,14 @@ function showTab(id, tab) { // Update tab buttons const expandRow = document.getElementById('expand-' + id); if (!expandRow) return; - + expandRow.querySelectorAll('.tab-btn').forEach(btn => { btn.classList.remove('active'); if (btn.textContent.toLowerCase() === tab) { btn.classList.add('active'); } }); - + // Update tab content expandRow.querySelectorAll('.tab-content').forEach(content => { content.classList.remove('active'); diff --git a/server/handlers.go b/server/handlers.go index 2d6cd14281f8bb6df4a677624dd184ea460e1bb8..b6cfa46e03137f5e043a3832983d4499cddaf2a9 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -357,30 +357,7 @@ func (s *Server) serveIndexWithInit(w http.ResponseWriter, r *http.Request, fs h } // Build initialization data - type ModelInfo struct { - ID string `json:"id"` - Ready bool `json:"ready"` - MaxContextTokens int `json:"max_context_tokens,omitempty"` - } - - var modelList []ModelInfo - if s.predictableOnly { - modelList = append(modelList, ModelInfo{ID: "predictable", Ready: true, MaxContextTokens: 200000}) - } else { - modelIDs := s.llmManager.GetAvailableModels() - for _, id := range modelIDs { - // Skip predictable model unless predictable-only flag is set - if id == "predictable" { - continue - } - svc, err := s.llmManager.GetService(id) - maxCtx := 0 - if err == nil && svc != nil { - maxCtx = svc.TokenContextWindow() - } - modelList = append(modelList, ModelInfo{ID: id, Ready: err == nil, MaxContextTokens: maxCtx}) - } - } + modelList := s.getModelList() // Select default model - use configured default if available, otherwise first ready model defaultModel := s.defaultModel @@ -913,6 +890,52 @@ func (s *Server) handleVersion(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(version.GetInfo()) } +// ModelInfo represents a model in the API response +type ModelInfo struct { + ID string `json:"id"` + DisplayName string `json:"display_name,omitempty"` + Ready bool `json:"ready"` + MaxContextTokens int `json:"max_context_tokens,omitempty"` +} + +// getModelList returns the list of available models +func (s *Server) getModelList() []ModelInfo { + var modelList []ModelInfo + if s.predictableOnly { + modelList = append(modelList, ModelInfo{ID: "predictable", Ready: true, MaxContextTokens: 200000}) + } else { + modelIDs := s.llmManager.GetAvailableModels() + for _, id := range modelIDs { + // Skip predictable model unless predictable-only flag is set + if id == "predictable" { + continue + } + svc, err := s.llmManager.GetService(id) + maxCtx := 0 + if err == nil && svc != nil { + maxCtx = svc.TokenContextWindow() + } + info := ModelInfo{ID: id, Ready: err == nil, MaxContextTokens: maxCtx} + // Add display name from model info + if modelInfo := s.llmManager.GetModelInfo(id); modelInfo != nil { + info.DisplayName = modelInfo.DisplayName + } + modelList = append(modelList, info) + } + } + return modelList +} + +// handleModels returns the list of available models +func (s *Server) 
handleModels(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(s.getModelList()) +} + // handleArchivedConversations handles GET /api/conversations/archived func (s *Server) handleArchivedConversations(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { diff --git a/server/server.go b/server/server.go index d2fe39610827e4baa65ea103e4591781aafba40b..e96f34fb95651f1dc7ef7951c29a6a06817df167 100644 --- a/server/server.go +++ b/server/server.go @@ -69,6 +69,7 @@ type LLMProvider interface { GetService(modelID string) (llm.Service, error) GetAvailableModels() []string HasModel(modelID string) bool + GetModelInfo(modelID string) *models.ModelInfo } // NewLLMServiceManager creates a new LLM service manager from config @@ -261,6 +262,14 @@ func (s *Server) RegisterRoutes(mux *http.ServeMux) { mux.HandleFunc("/api/read", s.handleRead) // Serves images mux.Handle("/api/write-file", http.HandlerFunc(s.handleWriteFile)) // Small response + // Custom models API + mux.Handle("/api/custom-models", http.HandlerFunc(s.handleCustomModels)) + mux.Handle("/api/custom-models/", http.HandlerFunc(s.handleCustomModel)) + mux.Handle("/api/custom-models-test", http.HandlerFunc(s.handleTestModel)) + + // Models API (dynamic list refresh) + mux.Handle("/api/models", http.HandlerFunc(s.handleModels)) + // Version endpoints mux.Handle("GET /version", http.HandlerFunc(s.handleVersion)) mux.Handle("GET /version-check", http.HandlerFunc(s.handleVersionCheck)) diff --git a/slug/slug.go b/slug/slug.go index c3258786113010c4d0b302b5391987499ac42f4a..87b27d879e3e4025a00449a5b922604f4c9ed298 100644 --- a/slug/slug.go +++ b/slug/slug.go @@ -10,15 +10,18 @@ import ( "shelley.exe.dev/db" "shelley.exe.dev/llm" + "shelley.exe.dev/models" ) // LLMServiceProvider defines the interface for getting LLM services type LLMServiceProvider interface { GetService(modelID string) (llm.Service, error) + GetAvailableModels() []string + GetModelInfo(modelID string) *models.ModelInfo } // GenerateSlug generates a slug for a conversation and updates the database -// If conversationModelID is provided, it will try to use that model first before falling back to the default list +// If conversationModelID is provided, it will be used as a fallback if no model is tagged with "slug" func GenerateSlug(ctx context.Context, llmProvider LLMServiceProvider, database *db.DB, logger *slog.Logger, conversationID, userMessage, conversationModelID string) (string, error) { baseSlug, err := generateSlugText(ctx, llmProvider, logger, userMessage, conversationModelID) if err != nil { @@ -53,15 +56,14 @@ func GenerateSlug(ctx context.Context, llmProvider LLMServiceProvider, database } // generateSlugText generates a human-readable slug for a conversation based on the user message -// If conversationModelID is "predictable", it will be used instead of the default preferred models +// Priority order: +// 1. If conversationModelID is "predictable", use it +// 2. Try models tagged with "slug" +// 3. 
+// Fall back to the conversation's model (conversationModelID)
 func generateSlugText(ctx context.Context, llmProvider LLMServiceProvider, logger *slog.Logger, userMessage, conversationModelID string) (string, error) {
-	// Try different models in order of preference
 	var llmService llm.Service
 	var err error
 
-	// Preferred models in order of preference
-	preferredModels := []string{"qwen3-coder-fireworks", "gpt5-mini", "gpt-5-thinking-mini", "claude-sonnet-4.5", "predictable"}
-
 	// If conversation is using predictable model, use it for slug generation too
 	if conversationModelID == "predictable" {
 		llmService, err = llmProvider.GetService("predictable")
@@ -72,15 +74,29 @@ func generateSlugText(ctx context.Context, llmProvider LLMServiceProvider, logge
 		}
 	}
 
-	// If we didn't get the predictable service, try the preferred models
+	// Try models tagged with "slug"
 	if llmService == nil {
-		for _, model := range preferredModels {
-			llmService, err = llmProvider.GetService(model)
-			if err == nil {
-				logger.Debug("Using preferred model for slug generation", "model", model)
+		for _, modelID := range llmProvider.GetAvailableModels() {
+			info := llmProvider.GetModelInfo(modelID)
+			if info != nil && strings.Contains(info.Tags, "slug") {
+				llmService, err = llmProvider.GetService(modelID)
+				if err == nil {
+					logger.Debug("Using slug-tagged model for slug generation", "model", modelID)
+				} else {
+					logger.Debug("Failed to get slug-tagged model", "model", modelID, "error", err)
+				}
 				break
 			}
-			logger.Debug("Model not available for slug generation", "model", model, "error", err)
 		}
 	}
+
+	// Fall back to the conversation's model
+	if llmService == nil && conversationModelID != "" && conversationModelID != "predictable" {
+		llmService, err = llmProvider.GetService(conversationModelID)
+		if err == nil {
+			logger.Debug("Using conversation model for slug generation", "model", conversationModelID)
+		} else {
+			logger.Debug("Conversation model not available for slug generation", "model", conversationModelID, "error", err)
+		}
+	}
@@ -134,9 +150,6 @@ Respond with only the slug, nothing else.`, userMessage)
 		return "", fmt.Errorf("generated slug is empty after sanitization")
 	}
 
-	// Note: We don't check for uniqueness here since we're generating for a new conversation
-	// and the database will handle any conflicts
-
 	return slug, nil
 }
diff --git a/slug/slug_test.go b/slug/slug_test.go
index 8b6baf814a5368d3b46ecee0262daefec6d67d13..6ef78739742092889950714aa9b9fe3020b87d7a 100644
--- a/slug/slug_test.go
+++ b/slug/slug_test.go
@@ -9,6 +9,7 @@ import (
 
 	"shelley.exe.dev/db"
 	"shelley.exe.dev/llm"
+	"shelley.exe.dev/models"
 )
 
 func TestSanitize(t *testing.T) {
@@ -100,6 +101,14 @@ func (m *MockLLMProvider) GetService(modelID string) (llm.Service, error) {
 	return m.Service, nil
 }
 
+func (m *MockLLMProvider) GetAvailableModels() []string {
+	return []string{"mock"}
+}
+
+func (m *MockLLMProvider) GetModelInfo(modelID string) *models.ModelInfo {
+	return nil
+}
+
 // TestGenerateSlug_DatabaseIntegration tests slug generation with actual database conflicts
 func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
 	// Create temporary database
@@ -135,7 +144,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
 	}
 
 	// Generate first slug - should succeed with "test-slug"
-	slug1, err := GenerateSlug(ctx, mockLLM, database, logger, conv1.ConversationID, "Test message", "")
+	slug1, err := GenerateSlug(ctx, mockLLM, database, logger, conv1.ConversationID, "Test message", "test-model")
 	if err != nil {
 		t.Fatalf("Failed to generate first slug: %v", err)
 	}
@@ -150,7 +159,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
 	}
 
 	// Generate second slug - should get "test-slug-1" due to conflict
-	slug2, err := GenerateSlug(ctx, mockLLM, database, logger, conv2.ConversationID, "Test message", "")
+	slug2, err := GenerateSlug(ctx, mockLLM, database, logger, conv2.ConversationID, "Test message", "test-model")
 	if err != nil {
 		t.Fatalf("Failed to generate second slug: %v", err)
 	}
@@ -165,7 +174,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
 	}
 
 	// Generate third slug - should get "test-slug-2" due to conflict
-	slug3, err := GenerateSlug(ctx, mockLLM, database, logger, conv3.ConversationID, "Test message", "")
+	slug3, err := GenerateSlug(ctx, mockLLM, database, logger, conv3.ConversationID, "Test message", "test-model")
 	if err != nil {
 		t.Fatalf("Failed to generate third slug: %v", err)
 	}
@@ -203,6 +212,14 @@ func (m *MockLLMProviderWithError) GetService(modelID string) (llm.Service, erro
 	return nil, fmt.Errorf("model not available")
 }
 
+func (m *MockLLMProviderWithError) GetAvailableModels() []string {
+	return []string{}
+}
+
+func (m *MockLLMProviderWithError) GetModelInfo(modelID string) *models.ModelInfo {
+	return nil
+}
+
 // MockLLMProviderWithServiceError provides a mock LLM provider that returns a service with error
 type MockLLMProviderWithServiceError struct{}
 
@@ -210,6 +227,14 @@ func (m *MockLLMProviderWithServiceError) GetService(modelID string) (llm.Servic
 	return &MockLLMServiceWithError{}, nil
 }
 
+func (m *MockLLMProviderWithServiceError) GetAvailableModels() []string {
+	return []string{"mock"}
+}
+
+func (m *MockLLMProviderWithServiceError) GetModelInfo(modelID string) *models.ModelInfo {
+	return nil
+}
+
 // TestGenerateSlug_LLMError tests error handling when LLM service fails
 func TestGenerateSlug_LLMError(t *testing.T) {
 	mockLLM := &MockLLMProviderWithServiceError{}
@@ -218,8 +243,8 @@ func TestGenerateSlug_LLMError(t *testing.T) {
 		Level: slog.LevelWarn,
 	}))
 
-	// Test that LLM error is properly propagated
-	_, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "")
+	// Test that LLM error is properly propagated (pass a model ID so we get a service)
+	_, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "test-model")
 	if err == nil {
 		t.Error("Expected error from LLM service, got nil")
 	}
@@ -255,7 +280,7 @@ func TestGenerateSlug_EmptyResponse(t *testing.T) {
 		Level: slog.LevelWarn,
 	}))
 
-	_, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "")
+	_, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "test-model")
 	if err == nil {
 		t.Error("Expected error for empty LLM response, got nil")
 	}
@@ -271,6 +296,14 @@ func (m *MockLLMProviderWithEmptyResponse) GetService(modelID string) (llm.Servi
 	return &MockLLMServiceEmptyResponse{}, nil
 }
 
+func (m *MockLLMProviderWithEmptyResponse) GetAvailableModels() []string {
+	return []string{"mock"}
+}
+
+func (m *MockLLMProviderWithEmptyResponse) GetModelInfo(modelID string) *models.ModelInfo {
+	return nil
+}
+
 // MockLLMServiceEmptyResponse provides a mock LLM service that returns empty response
 type MockLLMServiceEmptyResponse struct{}
 
@@ -301,7 +334,7 @@ func TestGenerateSlug_SanitizationError(t *testing.T) {
 		Level: slog.LevelWarn,
 	}))
 
-	_, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "")
+	_, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "test-model")
 	if err == nil {
 		t.Error("Expected error for empty slug after sanitization, got nil")
 	}
@@ -357,7 +390,7 @@ func TestGenerateSlug_DatabaseError(t *testing.T) {
 	}
 	closedDB.Close()
 
-	_, err = GenerateSlug(ctx, mockLLM, closedDB, logger, "test-conversation-id", "Test message", "")
+	_, err = GenerateSlug(ctx, mockLLM, closedDB, logger, "test-conversation-id", "Test message", "test-model")
 	if err == nil {
 		t.Error("Expected database error, got nil")
 	}
@@ -386,9 +419,9 @@ func TestGenerateSlug_PredictableModel(t *testing.T) {
 	}
 }
 
-// TestGenerateSlug_PredictableModelFallback tests fallback when predictable model is not available
-func TestGenerateSlug_PredictableModelFallback(t *testing.T) {
-	// Mock LLM provider that doesn't have predictable model but has other models
+// TestGenerateSlug_ConversationModelFallback tests fallback to conversation model when no slug-tagged models exist
+func TestGenerateSlug_ConversationModelFallback(t *testing.T) {
+	// Mock LLM provider that doesn't have predictable model but has a conversation model
 	mockLLM := &MockLLMProviderPredictableFallback{
 		fallbackService: &MockLLMService{
 			ResponseText: "fallback-slug",
@@ -399,10 +432,10 @@ func TestGenerateSlug_PredictableModelFallback(t *testing.T) {
 		Level: slog.LevelDebug,
 	}))
 
-	// Test that fallback to preferred models works when predictable is not available
-	slug, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "predictable")
+	// Test that fallback to conversation model works when no slug-tagged models exist
+	slug, err := generateSlugText(context.Background(), mockLLM, logger, "Test message", "my-custom-model")
 	if err != nil {
-		t.Fatalf("Failed to generate slug with fallback: %v", err)
+		t.Fatalf("Failed to generate slug with conversation model fallback: %v", err)
 	}
 	if slug != "fallback-slug" {
 		t.Errorf("Expected 'fallback-slug', got %q", slug)
@@ -420,3 +453,11 @@ func (m *MockLLMProviderPredictableFallback) GetService(modelID string) (llm.Ser
 	}
 	return m.fallbackService, nil
 }
+
+func (m *MockLLMProviderPredictableFallback) GetAvailableModels() []string {
+	return []string{"my-custom-model"}
+}
+
+func (m *MockLLMProviderPredictableFallback) GetModelInfo(modelID string) *models.ModelInfo {
+	return nil
+}
diff --git a/test/server_test.go b/test/server_test.go
index da9de5678f4371e2f2fceb7a8e1bf057944f6d81..2fa0f7101d65459bc6822add6b2c22fb0f102099 100644
--- a/test/server_test.go
+++ b/test/server_test.go
@@ -23,6 +23,7 @@ import (
 	"shelley.exe.dev/db/generated"
 	"shelley.exe.dev/llm"
 	"shelley.exe.dev/loop"
+	"shelley.exe.dev/models"
 	"shelley.exe.dev/server"
 	"shelley.exe.dev/slug"
 )
@@ -907,6 +908,10 @@ func (m *inspectableLLMManager) HasModel(modelID string) bool {
 	return modelID == "predictable"
 }
 
+func (m *inspectableLLMManager) GetModelInfo(modelID string) *models.ModelInfo {
+	return nil
+}
+
 func TestVersionEndpoint(t *testing.T) {
 	// Create temp DB-backed server
 	ctx := context.Background()
diff --git a/ui/src/App.tsx b/ui/src/App.tsx
index 5e7d38c697ab13e7ce5ab6cf913e6ff5db943110..955847977129e4273b3c4dbaf8698cdb82b01fe2 100644
--- a/ui/src/App.tsx
+++ b/ui/src/App.tsx
@@ -2,6 +2,7 @@ import React, { useState, useEffect, useCallback, useRef } from "react";
 import ChatInterface from "./components/ChatInterface";
 import ConversationDrawer from "./components/ConversationDrawer";
 import CommandPalette from "./components/CommandPalette";
+import ModelsModal from "./components/ModelsModal";
 import { Conversation, ConversationWithState, ConversationListUpdate } from "./types";
 import { api } from "./services/api";
@@ -65,6 +66,8 @@ function App() {
   const [drawerCollapsed, setDrawerCollapsed] = useState(false);
   const [commandPaletteOpen, setCommandPaletteOpen] = useState(false);
   const [diffViewerTrigger, setDiffViewerTrigger] = useState(0);
+  const [modelsModalOpen, setModelsModalOpen] = useState(false);
+  const [modelsRefreshTrigger, setModelsRefreshTrigger] = useState(0);
   const [loading, setLoading] = useState(true);
   const [error, setError] = useState(null);
   const [subagentUpdate, setSubagentUpdate] = useState(null);
@@ -348,6 +351,7 @@ function App() {
         isDrawerCollapsed={drawerCollapsed}
         onToggleDrawerCollapse={toggleDrawerCollapsed}
         openDiffViewerTrigger={diffViewerTrigger}
+        modelsRefreshTrigger={modelsRefreshTrigger}
       />
@@ -368,9 +372,19 @@ function App() {
           setDiffViewerTrigger((prev) => prev + 1);
           setCommandPaletteOpen(false);
         }}
+        onOpenModelsModal={() => {
+          setModelsModalOpen(true);
+          setCommandPaletteOpen(false);
+        }}
         hasCwd={!!(currentConversation?.cwd || mostRecentCwd)}
       />
 
+      <ModelsModal
+        isOpen={modelsModalOpen}
+        onClose={() => setModelsModalOpen(false)}
+        onModelsChanged={() => setModelsRefreshTrigger((prev) => prev + 1)}
+      />
+
       {/* Backdrop for mobile drawer */}
       {drawerOpen && (
         <div onClick={() => setDrawerOpen(false)} />
diff --git a/ui/src/components/ChatInterface.tsx b/ui/src/components/ChatInterface.tsx
index a2251a115671add51ca3515784d68391a741b154..8c26f0a4d6e452007462ae7ee1f73a5aba35239d 100644
--- a/ui/src/components/ChatInterface.tsx
+++ b/ui/src/components/ChatInterface.tsx
@@ -394,6 +394,7 @@ interface ChatInterfaceProps {
   isDrawerCollapsed?: boolean;
   onToggleDrawerCollapse?: () => void;
   openDiffViewerTrigger?: number; // increment to trigger opening diff viewer
+  modelsRefreshTrigger?: number; // increment to trigger models list refresh
 }
 
 function ChatInterface({
@@ -409,12 +410,15 @@ function ChatInterface({
   isDrawerCollapsed,
   onToggleDrawerCollapse,
   openDiffViewerTrigger,
+  modelsRefreshTrigger,
 }: ChatInterfaceProps) {
   const [messages, setMessages] = useState([]);
   const [loading, setLoading] = useState(true);
   const [sending, setSending] = useState(false);
   const [error, setError] = useState(null);
-  const models = window.__SHELLEY_INIT__?.models || [];
+  const [models, setModels] = useState<
+    Array<{ id: string; display_name?: string; ready: boolean; max_context_tokens?: number }>
+  >(window.__SHELLEY_INIT__?.models || []);
   const [selectedModel, setSelectedModelState] = useState(() => {
     // First check localStorage for a sticky model preference
     const storedModel = localStorage.getItem("shelley_selected_model");
@@ -473,6 +477,26 @@ function ChatInterface({
       setCwdInitialized(true);
     }
   }, [mostRecentCwd, cwdInitialized]);
+
+  // Refresh models list when triggered (e.g., after custom model changes) or when starting new conversation
+  useEffect(() => {
+    // Skip on initial mount with trigger=0, but always refresh when starting a new conversation
+    if (modelsRefreshTrigger === undefined) return;
+    if (modelsRefreshTrigger === 0 && conversationId !== null) return;
+    api
+      .getModels()
+      .then((newModels) => {
+        setModels(newModels);
+        // Also update the global init data so other components see the change
+        if (window.__SHELLEY_INIT__) {
+          window.__SHELLEY_INIT__.models = newModels;
+        }
+      })
+      .catch((err) => {
+        console.error("Failed to refresh models:", err);
+      });
+  }, [modelsRefreshTrigger, conversationId]);
+
   const [cwdError, setCwdError] = useState(null);
   const [editingModel, setEditingModel] = useState(false);
   const [showDirectoryPicker, setShowDirectoryPicker] = useState(false);
@@ -1459,7 +1483,7 @@ function ChatInterface({
             >
               {models.map((model) => (
                 <option key={model.id} value={model.id}>
-                  {model.id}
+                  {model.display_name || model.id}
                 </option>
               ))}
@@ -1469,7 +1493,7 @@ function ChatInterface({
               onClick={() => setEditingModel(true)}
               disabled={sending}
             >
-              {selectedModel}
+              {models.find((m) => m.id === selectedModel)?.display_name || selectedModel}
             </button>
           )}
diff --git a/ui/src/components/CommandPalette.tsx b/ui/src/components/CommandPalette.tsx
index bf4a8953401531ba2c8a0f171117171c060ae8e1..b97310b810cf28e6841df585275cd2338dc9edc5 100644
--- a/ui/src/components/CommandPalette.tsx
+++ b/ui/src/components/CommandPalette.tsx
@@ -19,6 +19,7 @@ interface CommandPaletteProps {
   onNewConversation: () => void;
   onSelectConversation: (id: string) => void;
   onOpenDiffViewer: () => void;
+  onOpenModelsModal: () => void;
   hasCwd: boolean;
 }
 
@@ -64,6 +65,7 @@ function CommandPalette({
   onNewConversation,
   onSelectConversation,
   onOpenDiffViewer,
+  onOpenModelsModal,
   hasCwd,
 }: CommandPaletteProps) {
   const [query, setQuery] = useState("");
@@ -161,8 +163,46 @@ function CommandPalette({
       });
     }
 
+    items.push({
+      id: "manage-models",
+      type: "action",
+      title: "Add/Remove Models/Keys",
+      subtitle: "Configure custom AI models and API keys",
+      icon: (
+        <svg … >
+          <path … />
+        </svg>
+      ),
+      action: () => {
+        onOpenModelsModal();
+        onClose();
+      },
+      keywords: [
+        "model",
+        "key",
+        "api",
+        "configure",
+        "settings",
+        "anthropic",
+        "openai",
+        "gemini",
+        "custom",
+      ],
+    });
+
     return items;
-  }, [onNewConversation, onOpenDiffViewer, onClose, hasCwd]);
+  }, [onNewConversation, onOpenDiffViewer, onOpenModelsModal, onClose, hasCwd]);
 
   // Convert conversations to command items
   const conversationToItem = useCallback(
diff --git a/ui/src/components/Modal.tsx b/ui/src/components/Modal.tsx
index 9b48166a710ef3a5a524ada633a23ac1288bf145..0f2d11816fedafad9910591a33651c78d4ade1fc 100644
--- a/ui/src/components/Modal.tsx
+++ b/ui/src/components/Modal.tsx
@@ -4,10 +4,12 @@ interface ModalProps {
   isOpen: boolean;
   onClose: () => void;
   title: string;
+  titleRight?: React.ReactNode;
   children: React.ReactNode;
+  className?: string;
 }
 
-function Modal({ isOpen, onClose, title, children }: ModalProps) {
+function Modal({ isOpen, onClose, title, titleRight, children, className }: ModalProps) {
   if (!isOpen) return null;
 
   const handleBackdropClick = (e: React.MouseEvent) => {
@@ -18,10 +20,11 @@ function Modal({ isOpen, onClose, title, titleRight, children, className }: ModalProps) {
 
   return (
-    <div className="modal">
+    <div className={`modal ${className || ""}`}>
       {/* Header */}
       <div className="modal-header">
         <h2>{title}</h2>
+        {titleRight && <div className="modal-title-right">{titleRight}</div>}
diff --git a/ui/src/components/ModelsModal.tsx b/ui/src/components/ModelsModal.tsx
new file mode 100644
--- /dev/null
+++ b/ui/src/components/ModelsModal.tsx
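+// ModelsModal: management UI for custom models. It shows which built-in models
+// are available, lists the configured custom models (display name, provider,
+// first tag, endpoint), and provides an add/edit form with provider selection,
+// an endpoint toggle, model-name presets, API key, max tokens, and tags, plus
+// a connection test whose outcome is surfaced via testResult.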
+function ModelsModal({ isOpen, onClose, onModelsChanged }: ModelsModalProps) {
+  … // state (models, builtInModels, loading, error, showForm, editingModelId, form, testResult) and handlers
+  return (
+    <Modal isOpen={isOpen} onClose={onClose} title="Models" className="modal-wide">
+      <div className="models-modal">
+        {error && (
+          <div className="models-error">
+            {error}
+            <button className="models-error-dismiss" onClick={…}>×</button>
+          </div>
+        )}
+
+        {loading ? (
+          <div className="models-loading">
+            Loading models...
+          </div>
+        ) : showForm ? (
+          // Add/Edit form
+          <div className="model-form">
+            <h3>{editingModelId ? "Edit Model" : "Add Model"}</h3>
+
+            {/* Provider Selection */}
+            <div className="form-group">
+              <label>Provider</label>
+              <div className="provider-buttons">
+                {(["anthropic", "openai", "openai-responses", "gemini"] as ProviderType[]).map(
+                  (p) => (
+                    <button key={p} className={`provider-btn ${form.provider_type === p ? "selected" : ""}`} onClick={…}>{PROVIDER_LABELS[p]}</button>
+                  ),
+                )}
+              </div>
+            </div>
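+            {/* The endpoint controls below switch between the provider's default URL (read-only) and a custom, editable URL via form.endpoint_custom */}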
+ +
+ + +
+ {form.endpoint_custom ? ( + setForm((prev) => ({ ...prev, endpoint: e.target.value }))} + placeholder="https://..." + className="form-input" + /> + ) : ( +
{form.endpoint}
+ )} +
+ + {/* Model Name with Presets */} +
+ +
+ {DEFAULT_MODELS[form.provider_type].map((preset) => ( + + ))} +
+ setForm((prev) => ({ ...prev, model_name: e.target.value }))} + placeholder="Model name (e.g., claude-sonnet-4-5)" + className="form-input" + /> +
+ + {/* Display Name */} +
+ + setForm((prev) => ({ ...prev, display_name: e.target.value }))} + placeholder="Name shown in model selector" + className="form-input" + /> +
+ + {/* API Key */} +
+ + setForm((prev) => ({ ...prev, api_key: e.target.value }))} + placeholder="Enter API key" + className="form-input" + autoComplete="off" + /> +
+ + {/* Max Tokens */} +
+ + + setForm((prev) => ({ ...prev, max_tokens: parseInt(e.target.value) || 200000 })) + } + className="form-input" + /> +
+ + {/* Tags */} +
+ + setForm((prev) => ({ ...prev, tags: e.target.value }))} + placeholder="comma-separated, e.g., slug, cheap" + className="form-input" + /> +
+ + {/* Test Result */} + {testResult && ( +
+ {testResult.success ? "✓" : "✗"} {testResult.message} +
+ )} + + {/* Form Actions */} +
+ + + +
+
+ ) : ( + // Model List + <> + {models.length === 0 && + builtInModels.length > 0 && + builtInModels[0] !== "predictable" && ( +
+

Built-in models available:

+
    + {builtInModels + .filter((m) => m !== "predictable") + .map((model) => ( +
  • {model}
  • + ))} +
+
+ )} + + {models.length > 0 && ( +
+ {models.map((model) => ( +
+
+
+ {model.display_name} + + {PROVIDER_LABELS[model.provider_type]} + + {model.tags && ( + + {model.tags.split(",")[0]} + + )} +
+
+ + + +
+
+
+ {model.model_name} + {model.endpoint} +
+
+ ))} +
+ )} + + )} +
+            )}
+
+            <button className="add-model-btn" onClick={…}>Add Model</button>
+          </>
+        )}
+      </div>
+    </Modal>
+  );
+}
+
+export default ModelsModal;
diff --git a/ui/src/services/api.ts b/ui/src/services/api.ts
index ed2d74b35f3edfcdba0123de11f310439c4aa9b7..8f3c6519d7e7b94a9731f18631a4a6787ba5a971 100644
--- a/ui/src/services/api.ts
+++ b/ui/src/services/api.ts
@@ -27,6 +27,16 @@ class ApiService {
     return response.json();
   }
 
+  async getModels(): Promise<
+    Array<{ id: string; display_name?: string; ready: boolean; max_context_tokens?: number }>
+  > {
+    const response = await fetch(`${this.baseUrl}/models`);
+    if (!response.ok) {
+      throw new Error(`Failed to get models: ${response.statusText}`);
+    }
+    return response.json();
+  }
+
   async searchConversations(query: string): Promise<Conversation[]> {
     const params = new URLSearchParams({
       q: query,
@@ -255,3 +265,115 @@ class ApiService {
 }
 
 export const api = new ApiService();
+
+// Custom models API
+export interface CustomModel {
+  model_id: string;
+  display_name: string;
+  provider_type: "anthropic" | "openai" | "openai-responses" | "gemini";
+  endpoint: string;
+  api_key: string;
+  model_name: string;
+  max_tokens: number;
+  tags: string; // Comma-separated tags (e.g., "slug" for slug generation)
+}
+
+export interface CreateCustomModelRequest {
+  display_name: string;
+  provider_type: "anthropic" | "openai" | "openai-responses" | "gemini";
+  endpoint: string;
+  api_key: string;
+  model_name: string;
+  max_tokens: number;
+  tags: string; // Comma-separated tags
+}
+
+export interface TestCustomModelRequest {
+  model_id?: string; // If provided with empty api_key, use stored key
+  provider_type: "anthropic" | "openai" | "openai-responses" | "gemini";
+  endpoint: string;
+  api_key: string;
+  model_name: string;
+}
+
+class CustomModelsApi {
+  private baseUrl = "/api";
+
+  private postHeaders = {
+    "Content-Type": "application/json",
+    "X-Shelley-Request": "1",
+  };
+
+  async getCustomModels(): Promise<CustomModel[]> {
+    const response = await fetch(`${this.baseUrl}/custom-models`);
+    if (!response.ok) {
+      throw new Error(`Failed to get custom models: ${response.statusText}`);
+    }
+    return response.json();
+  }
+
+  async createCustomModel(request: CreateCustomModelRequest): Promise<CustomModel> {
+    const response = await fetch(`${this.baseUrl}/custom-models`, {
+      method: "POST",
+      headers: this.postHeaders,
+      body: JSON.stringify(request),
+    });
+    if (!response.ok) {
+      throw new Error(`Failed to create custom model: ${response.statusText}`);
+    }
+    return response.json();
+  }
+
+  async updateCustomModel(
+    modelId: string,
+    request: Partial<CreateCustomModelRequest>,
+  ): Promise<CustomModel> {
+    const response = await fetch(`${this.baseUrl}/custom-models/${modelId}`, {
+      method: "PUT",
+      headers: this.postHeaders,
+      body: JSON.stringify(request),
+    });
+    if (!response.ok) {
+      throw new Error(`Failed to update custom model: ${response.statusText}`);
+    }
+    return response.json();
+  }
+
+  async deleteCustomModel(modelId: string): Promise<void> {
+    const response = await fetch(`${this.baseUrl}/custom-models/${modelId}`, {
+      method: "DELETE",
+      headers: { "X-Shelley-Request": "1" },
+    });
+    if (!response.ok) {
+      throw new Error(`Failed to delete custom model: ${response.statusText}`);
+    }
+  }
+
+  async duplicateCustomModel(modelId: string, displayName?: string): Promise<CustomModel> {
+    const response = await fetch(`${this.baseUrl}/custom-models/${modelId}/duplicate`, {
+      method: "POST",
+      headers: this.postHeaders,
+      body: JSON.stringify({ display_name: displayName }),
+    });
+    if (!response.ok) {
+      throw new Error(`Failed to duplicate custom model: ${response.statusText}`);
+    }
+    return response.json();
+  }
+
+  async testCustomModel(
+    request: TestCustomModelRequest,
+  ): Promise<{ success: boolean; message: string }> {
+    const response = await fetch(`${this.baseUrl}/custom-models-test`, {
+      method: "POST",
+      headers: this.postHeaders,
+      body: JSON.stringify(request),
+    });
+    if (!response.ok) {
+      throw new Error(`Failed to test custom model: ${response.statusText}`);
+    }
+    return response.json();
+  }
+}
+
+export const customModelsApi = new CustomModelsApi();
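+
+// Illustrative usage of customModelsApi (the concrete values below are
+// assumptions for the sketch, not part of this module):
+//
+//   const created = await customModelsApi.createCustomModel({
+//     display_name: "My Claude",
+//     provider_type: "anthropic",
+//     endpoint: "https://api.anthropic.com",
+//     api_key: "sk-ant-...",
+//     model_name: "claude-sonnet-4-5",
+//     max_tokens: 200000,
+//     tags: "slug",
+//   });
+//   // An empty api_key combined with a model_id re-tests using the stored key:
+//   const check = await customModelsApi.testCustomModel({
+//     model_id: created.model_id,
+//     provider_type: created.provider_type,
+//     endpoint: created.endpoint,
+//     api_key: "",
+//     model_name: created.model_name,
+//   });
+//   if (check.success) {
+//     const refreshed = await api.getModels(); // refresh the model selector
+//   }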
diff --git a/ui/src/styles.css b/ui/src/styles.css
index d145f1b3a339e5926d08e0eabdc415a572532ecb..b6e3ebaeda1f67aea18b7b1f03154a11d52270a5 100644
--- a/ui/src/styles.css
+++ b/ui/src/styles.css
@@ -306,6 +306,11 @@ button {
   background: var(--bg-tertiary);
 }
 
+.btn-sm {
+  padding: 0.375rem 0.75rem;
+  font-size: 0.8125rem;
+}
+
 .btn-icon {
   padding: 0.5rem;
   border-radius: 0.375rem;
@@ -1992,6 +1997,10 @@ button {
   overflow: hidden;
 }
 
+.modal.modal-wide {
+  max-width: 40rem;
+}
+
 .modal-header {
   display: flex;
   align-items: center;
@@ -2005,8 +2014,15 @@ button {
   font-weight: 600;
 }
 
+.modal-title-right {
+  margin-left: auto;
+  margin-right: 1rem;
+}
+
 .modal-body {
   padding: 1rem;
+  overflow-y: auto;
+  max-height: calc(80vh - 60px);
 }
 
 /* Form Elements */
@@ -3795,6 +3811,158 @@ svg {
   }
 }
 
+/* Models Modal Styles */
+.models-modal {
+  padding: 0.5rem;
+}
+
+.models-loading {
+  display: flex;
+  align-items: center;
+  justify-content: center;
+  gap: 0.5rem;
+  padding: 2rem;
+  color: var(--text-secondary);
+}
+
+.models-error {
+  display: flex;
+  align-items: center;
+  justify-content: space-between;
+  padding: 0.75rem 1rem;
+  background: var(--error-bg);
+  border: 1px solid var(--error-border);
+  border-radius: 0.375rem;
+  color: var(--error-text);
+  margin-bottom: 1rem;
+}
+
+.models-error-dismiss {
+  background: none;
+  border: none;
+  font-size: 1.25rem;
+  cursor: pointer;
+  color: inherit;
+  padding: 0;
+  line-height: 1;
+}
+
+.models-info {
+  padding: 0.75rem 1rem;
+  background: var(--bg-tertiary);
+  border-radius: 0.375rem;
+  margin-bottom: 1rem;
+}
+
+.models-info p {
+  margin: 0 0 0.5rem 0;
+  color: var(--text-secondary);
+}
+
+.builtin-list {
+  margin: 0;
+  padding-left: 1.25rem;
+  color: var(--text-primary);
+}
+
+.builtin-list li {
+  margin: 0.25rem 0;
+}
+
+.models-list {
+  display: flex;
+  flex-direction: column;
+  gap: 0.75rem;
+  margin-bottom: 1rem;
+}
+
+.model-card {
+  padding: 0.75rem 1rem;
+  background: var(--bg-base);
+  border: 1px solid var(--border);
+  border-radius: 0.375rem;
+}
+
+.model-header {
+  display: flex;
+  justify-content: space-between;
+  align-items: flex-start;
+  gap: 0.5rem;
+}
+
+.model-info {
+  display: flex;
+  flex-wrap: wrap;
+  align-items: center;
+  gap: 0.5rem;
+}
+
+.model-name {
+  font-weight: 500;
+  color: var(--text-primary);
+}
+
+.model-provider {
+  font-size: 0.75rem;
+  padding: 0.125rem 0.5rem;
+  background: var(--bg-tertiary);
+  border-radius: 0.25rem;
+  color: var(--text-secondary);
+}
+
+.model-badge {
+  font-size: 0.625rem;
+  text-transform: uppercase;
+  letter-spacing: 0.05em;
+  padding: 0.125rem 0.375rem;
+  background: var(--blue-bg);
+  border: 1px solid var(--blue-border);
+  border-radius: 0.25rem;
+  color: var(--blue-text);
+}
+
+.model-actions {
+  display: flex;
+  gap: 0.25rem;
+  flex-shrink: 0;
+}
+
+.model-details {
+  margin-top: 0.5rem;
+  display: flex;
+  flex-direction: column;
+  gap: 0.25rem;
+}
+
+.model-api-name {
+  font-size: 0.75rem;
+  font-family: var(--font-mono);
+  color: var(--text-secondary);
+}
+
+.model-endpoint {
+  font-size: 0.75rem;
+  color: var(--text-tertiary);
+  word-break: break-all;
+}
+
+.btn-icon {
+  padding: 0.25rem;
+  border-radius: 0.25rem;
+  color: var(--text-secondary);
+  background: transparent;
+  border: none;
+  cursor: pointer;
+  display: flex;
+  align-items: center;
+  justify-content: center;
+}
+
+.btn-icon:hover {
+  background: var(--bg-tertiary);
+  color: var(--text-primary);
+}
+
 /* Version Checker Styles */
 .version-update-dot {
   position: absolute;
@@ -4018,3 +4186,243 @@
 .version-btn-primary:hover:not(:disabled) {
   background: var(--primary-dark);
 }
+
+.btn-icon.btn-danger:hover {
+  background: var(--error-bg);
+  color: var(--error-text);
+}
+
+.add-model-btn {
+  width: 100%;
+}
+
+/* Model Form */
+.model-form h3 {
+  margin: 0 0 0.75rem 0;
+  font-size: 1rem;
+  font-weight: 600;
+}
+
+.form-group {
+  margin-bottom: 0.75rem;
+}
+
+.form-group label {
+  display: block;
+  font-size: 0.875rem;
+  font-weight: 500;
+  margin-bottom: 0.25rem;
+  color: var(--text-primary);
+}
+
+.form-group label .optional {
+  font-weight: 400;
+  color: var(--text-secondary);
+}
+
+.form-input {
+  width: 100%;
+  padding: 0.5rem 0.75rem;
+  background: var(--bg-base);
+  border: 1px solid var(--border);
+  border-radius: 0.375rem;
+  color: var(--text-primary);
+  font-family: inherit;
+  font-size: 0.875rem;
+}
+
+.form-input:focus {
+  outline: none;
+  border-color: var(--primary);
+  box-shadow: 0 0 0 2px rgba(37, 99, 235, 0.2);
+}
+
+.form-checkbox {
+  display: flex;
+  align-items: center;
+}
+
+.form-checkbox label {
+  display: flex;
+  align-items: center;
+  gap: 0.5rem;
+  cursor: pointer;
+  margin: 0;
+}
+
+.form-checkbox input[type="checkbox"] {
+  width: 1rem;
+  height: 1rem;
+  cursor: pointer;
+}
+
+.info-icon-wrapper {
+  position: relative;
+  display: inline-flex;
+  align-items: center;
+  margin-left: 0.25rem;
+}
+
+.info-icon {
+  display: inline-flex;
+  align-items: center;
+  justify-content: center;
+  color: var(--text-tertiary);
+  cursor: pointer;
+}
+
+.info-icon:hover {
+  color: var(--text-secondary);
+}
+
+.info-tooltip {
+  position: absolute;
+  top: 50%;
+  left: 100%;
+  transform: translateY(-50%);
+  background: var(--bg-tertiary);
+  border: 1px solid var(--border);
+  border-radius: 0.375rem;
+  padding: 0.5rem 0.75rem;
+  font-size: 0.75rem;
+  font-weight: 400;
+  color: var(--text-primary);
+  white-space: normal;
+  width: 220px;
+  margin-left: 0.375rem;
+  box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15);
+  z-index: 100;
+}
+
+.info-tooltip::after {
+  content: "";
+  position: absolute;
+  top: 50%;
+  right: 100%;
+  transform: translateY(-50%);
+  border: 6px solid transparent;
+  border-right-color: var(--border);
+}
+
+.provider-buttons {
+  display: flex;
+  gap: 0.5rem;
+}
+
+.provider-btn {
+  flex: 1;
+  padding: 0.5rem;
+  background: var(--bg-base);
+  border: 1px solid var(--border);
+  border-radius: 0.375rem;
+  color: var(--text-primary);
+  font-size: 0.875rem;
+  cursor: pointer;
+  transition: all 0.15s;
+}
+
+.provider-btn:hover {
+  background: var(--bg-tertiary);
+}
+
+.provider-btn.selected {
+  background: var(--primary);
+  border-color: var(--primary);
+  color: white;
+}
+
+.endpoint-toggle {
+  display: flex;
+  gap: 0;
+  margin-bottom: 0.5rem;
+}
+
+.toggle-btn {
+  flex: 1;
+  padding: 0.375rem 0.75rem;
+  background: var(--bg-base);
+  border: 1px solid var(--border);
+  color: var(--text-secondary);
+  font-size: 0.75rem;
+  cursor: pointer;
+  transition: all 0.15s;
+}
+
+.toggle-btn:first-child {
+  border-radius: 0.375rem 0 0 0.375rem;
+}
+
+.toggle-btn:last-child {
+  border-radius: 0 0.375rem 0.375rem 0;
+  border-left: none;
+}
+
+.toggle-btn.selected {
+  background: var(--bg-tertiary);
+  color: var(--text-primary);
+}
+
+.endpoint-display {
+  padding: 0.5rem 0.75rem;
+  background: var(--bg-tertiary);
+  border-radius: 0.375rem;
+  font-size: 0.75rem;
+  color: var(--text-secondary);
+  word-break: break-all;
+}
+
+.model-presets {
+  display: flex;
+  flex-wrap: wrap;
+  gap: 0.375rem;
+  margin-bottom: 0.5rem;
+}
+
+.preset-btn {
+  padding: 0.25rem 0.625rem;
+  background: var(--bg-base);
+  border: 1px solid var(--border);
+  border-radius: 0.25rem;
+  color: var(--text-secondary);
+  font-size: 0.75rem;
+  cursor: pointer;
+  transition: all 0.15s;
+}
+
+.preset-btn:hover {
+  background: var(--bg-tertiary);
+  color: var(--text-primary);
+}
+
+.preset-btn.selected {
+  background: var(--blue-bg);
+  border-color: var(--blue-border);
+  color: var(--blue-text);
+}
+
+.test-result {
+  padding: 0.75rem 1rem;
+  border-radius: 0.375rem;
+  margin-bottom: 1rem;
+  font-size: 0.875rem;
+}
+
+.test-result.success {
+  background: var(--success-bg);
+  border: 1px solid var(--success-border);
+  color: var(--success-text);
+}
+
+.test-result.error {
+  background: var(--error-bg);
+  border: 1px solid var(--error-border);
+  color: var(--error-text);
+}
+
+.form-actions {
+  display: flex;
+  gap: 0.5rem;
+  justify-content: flex-end;
+  margin-top: 1rem;
+  padding-bottom: 0.5rem;
+}
diff --git a/ui/src/types.ts b/ui/src/types.ts
index 64a3401c4f93dbbfa5653d935f34b20354f73291..7010aa70e3aae865c6706cca719fcbcb88519eab 100644
--- a/ui/src/types.ts
+++ b/ui/src/types.ts
@@ -49,6 +49,7 @@ export interface LLMContent {
 // API types
 export interface Model {
   id: string;
+  display_name?: string;
   ready: boolean;
   max_context_tokens?: number;
 }