1package server
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "net/http"
8 "strings"
9 "time"
10
11 "github.com/google/uuid"
12 "shelley.exe.dev/db/generated"
13 "shelley.exe.dev/llm"
14 "shelley.exe.dev/llm/ant"
15 "shelley.exe.dev/llm/gem"
16 "shelley.exe.dev/llm/oai"
17)
18
19// ModelAPI is the API representation of a model
// ModelAPI is the API representation of a model.
//
// NOTE(review): APIKey is serialized into JSON here, so list/get
// responses return the stored credential to clients — confirm this
// exposure is intended.
type ModelAPI struct {
	ModelID      string `json:"model_id"`      // Stable identifier, e.g. "custom-1a2b3c4d".
	DisplayName  string `json:"display_name"`  // Human-readable name shown in the UI.
	ProviderType string `json:"provider_type"` // One of "anthropic", "openai", "openai-responses", "gemini".
	Endpoint     string `json:"endpoint"`      // Provider API base URL.
	APIKey       string `json:"api_key"`       // Credential used to call the provider.
	ModelName    string `json:"model_name"`    // Provider-side model identifier.
	MaxTokens    int64  `json:"max_tokens"`    // Token limit; create/update default this to 200000 when <= 0.
	Tags         string `json:"tags"` // Comma-separated tags (e.g., "slug" for slug generation)
}
30
31// CreateModelRequest is the request body for creating a model
// CreateModelRequest is the request body for creating a model.
//
// DisplayName, ProviderType, Endpoint, APIKey, and ModelName are all
// required by the handler; MaxTokens values <= 0 are replaced with a
// 200000 default, and Tags is optional.
type CreateModelRequest struct {
	DisplayName  string `json:"display_name"`  // Required. Human-readable name.
	ProviderType string `json:"provider_type"` // Required. "anthropic", "openai", "openai-responses", or "gemini".
	Endpoint     string `json:"endpoint"`      // Required. Provider API base URL.
	APIKey       string `json:"api_key"`       // Required. Credential stored with the model.
	ModelName    string `json:"model_name"`    // Required. Provider-side model identifier.
	MaxTokens    int64  `json:"max_tokens"`    // Optional; defaults to 200000 when <= 0.
	Tags         string `json:"tags"` // Comma-separated tags
}
41
42// UpdateModelRequest is the request body for updating a model
// UpdateModelRequest is the request body for updating a model.
//
// It mirrors CreateModelRequest with one difference: an empty APIKey
// means "keep the key already stored for this model", so the UI never
// has to echo the secret back.
type UpdateModelRequest struct {
	DisplayName  string `json:"display_name"`  // Human-readable name.
	ProviderType string `json:"provider_type"` // "anthropic", "openai", "openai-responses", or "gemini".
	Endpoint     string `json:"endpoint"`      // Provider API base URL.
	APIKey       string `json:"api_key"` // Empty string means keep existing
	ModelName    string `json:"model_name"`    // Provider-side model identifier.
	MaxTokens    int64  `json:"max_tokens"`    // Values <= 0 fall back to the 200000 default.
	Tags         string `json:"tags"` // Comma-separated tags
}
52
53// TestModelRequest is the request body for testing a model
// TestModelRequest is the request body for testing a model.
//
// Callers either supply full credentials inline, or set ModelID with an
// empty APIKey to reuse the key stored for that model.
type TestModelRequest struct {
	ModelID      string `json:"model_id,omitempty"` // If provided, use stored API key
	ProviderType string `json:"provider_type"`      // "anthropic", "openai", "openai-responses", or "gemini".
	Endpoint     string `json:"endpoint"`           // Provider API base URL.
	APIKey       string `json:"api_key"`            // May be empty when ModelID is set.
	ModelName    string `json:"model_name"`         // Provider-side model identifier.
}
61
62func toModelAPI(m generated.Model) ModelAPI {
63 return ModelAPI{
64 ModelID: m.ModelID,
65 DisplayName: m.DisplayName,
66 ProviderType: m.ProviderType,
67 Endpoint: m.Endpoint,
68 APIKey: m.ApiKey,
69 ModelName: m.ModelName,
70 MaxTokens: m.MaxTokens,
71 Tags: m.Tags,
72 }
73}
74
75func (s *Server) handleCustomModels(w http.ResponseWriter, r *http.Request) {
76 switch r.Method {
77 case http.MethodGet:
78 s.handleListModels(w, r)
79 case http.MethodPost:
80 s.handleCreateModel(w, r)
81 default:
82 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
83 }
84}
85
86func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
87 models, err := s.db.GetModels(r.Context())
88 if err != nil {
89 http.Error(w, fmt.Sprintf("Failed to get models: %v", err), http.StatusInternalServerError)
90 return
91 }
92
93 apiModels := make([]ModelAPI, len(models))
94 for i, m := range models {
95 apiModels[i] = toModelAPI(m)
96 }
97
98 w.Header().Set("Content-Type", "application/json")
99 json.NewEncoder(w).Encode(apiModels)
100}
101
102func (s *Server) handleCreateModel(w http.ResponseWriter, r *http.Request) {
103 var req CreateModelRequest
104 if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
105 http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
106 return
107 }
108
109 // Validate required fields
110 if req.DisplayName == "" || req.ProviderType == "" || req.Endpoint == "" || req.APIKey == "" || req.ModelName == "" {
111 http.Error(w, "display_name, provider_type, endpoint, api_key, and model_name are required", http.StatusBadRequest)
112 return
113 }
114
115 // Validate provider type
116 if req.ProviderType != "anthropic" && req.ProviderType != "openai" && req.ProviderType != "openai-responses" && req.ProviderType != "gemini" {
117 http.Error(w, "provider_type must be 'anthropic', 'openai', 'openai-responses', or 'gemini'", http.StatusBadRequest)
118 return
119 }
120
121 // Generate model ID
122 modelID := "custom-" + uuid.New().String()[:8]
123
124 // Default max tokens
125 if req.MaxTokens <= 0 {
126 req.MaxTokens = 200000
127 }
128
129 model, err := s.db.CreateModel(r.Context(), generated.CreateModelParams{
130 ModelID: modelID,
131 DisplayName: req.DisplayName,
132 ProviderType: req.ProviderType,
133 Endpoint: req.Endpoint,
134 ApiKey: req.APIKey,
135 ModelName: req.ModelName,
136 MaxTokens: req.MaxTokens,
137 Tags: req.Tags,
138 })
139 if err != nil {
140 http.Error(w, fmt.Sprintf("Failed to create model: %v", err), http.StatusInternalServerError)
141 return
142 }
143
144 // Refresh the model manager's cache
145 if err := s.llmManager.RefreshCustomModels(); err != nil {
146 s.logger.Warn("Failed to refresh custom models cache", "error", err)
147 }
148
149 w.Header().Set("Content-Type", "application/json")
150 w.WriteHeader(http.StatusCreated)
151 json.NewEncoder(w).Encode(toModelAPI(*model))
152}
153
154func (s *Server) handleCustomModel(w http.ResponseWriter, r *http.Request) {
155 // Extract model ID from URL path: /api/custom-models/{id} or /api/custom-models/{id}/duplicate
156 path := strings.TrimPrefix(r.URL.Path, "/api/custom-models/")
157 if path == "" {
158 http.Error(w, "Invalid model ID", http.StatusBadRequest)
159 return
160 }
161
162 // Check for /duplicate suffix
163 if strings.HasSuffix(path, "/duplicate") {
164 modelID := strings.TrimSuffix(path, "/duplicate")
165 if r.Method == http.MethodPost {
166 s.handleDuplicateModel(w, r, modelID)
167 } else {
168 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
169 }
170 return
171 }
172
173 if strings.Contains(path, "/") {
174 http.Error(w, "Invalid model ID", http.StatusBadRequest)
175 return
176 }
177 modelID := path
178
179 switch r.Method {
180 case http.MethodGet:
181 s.handleGetModel(w, r, modelID)
182 case http.MethodPut:
183 s.handleUpdateModel(w, r, modelID)
184 case http.MethodDelete:
185 s.handleDeleteModel(w, r, modelID)
186 default:
187 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
188 }
189}
190
191func (s *Server) handleGetModel(w http.ResponseWriter, r *http.Request, modelID string) {
192 model, err := s.db.GetModel(r.Context(), modelID)
193 if err != nil {
194 http.Error(w, fmt.Sprintf("Failed to get model: %v", err), http.StatusNotFound)
195 return
196 }
197
198 w.Header().Set("Content-Type", "application/json")
199 json.NewEncoder(w).Encode(toModelAPI(*model))
200}
201
202func (s *Server) handleUpdateModel(w http.ResponseWriter, r *http.Request, modelID string) {
203 // First, get the existing model to get the current API key if not provided
204 existing, err := s.db.GetModel(r.Context(), modelID)
205 if err != nil {
206 http.Error(w, fmt.Sprintf("Model not found: %v", err), http.StatusNotFound)
207 return
208 }
209
210 var req UpdateModelRequest
211 if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
212 http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
213 return
214 }
215
216 // Use existing API key if not provided
217 apiKey := req.APIKey
218 if apiKey == "" {
219 apiKey = existing.ApiKey
220 }
221
222 // Default max tokens
223 if req.MaxTokens <= 0 {
224 req.MaxTokens = 200000
225 }
226
227 model, err := s.db.UpdateModel(r.Context(), generated.UpdateModelParams{
228 DisplayName: req.DisplayName,
229 ProviderType: req.ProviderType,
230 Endpoint: req.Endpoint,
231 ApiKey: apiKey,
232 ModelName: req.ModelName,
233 MaxTokens: req.MaxTokens,
234 Tags: req.Tags,
235 ModelID: modelID,
236 })
237 if err != nil {
238 http.Error(w, fmt.Sprintf("Failed to update model: %v", err), http.StatusInternalServerError)
239 return
240 }
241
242 // Refresh the model manager's cache
243 if err := s.llmManager.RefreshCustomModels(); err != nil {
244 s.logger.Warn("Failed to refresh custom models cache", "error", err)
245 }
246
247 w.Header().Set("Content-Type", "application/json")
248 json.NewEncoder(w).Encode(toModelAPI(*model))
249}
250
251func (s *Server) handleDeleteModel(w http.ResponseWriter, r *http.Request, modelID string) {
252 err := s.db.DeleteModel(r.Context(), modelID)
253 if err != nil {
254 http.Error(w, fmt.Sprintf("Failed to delete model: %v", err), http.StatusInternalServerError)
255 return
256 }
257
258 // Refresh the model manager's cache
259 if err := s.llmManager.RefreshCustomModels(); err != nil {
260 s.logger.Warn("Failed to refresh custom models cache", "error", err)
261 }
262
263 w.WriteHeader(http.StatusNoContent)
264}
265
266// DuplicateModelRequest allows overriding fields when duplicating
// DuplicateModelRequest allows overriding fields when duplicating.
// All fields are optional; an absent DisplayName makes the handler
// derive one from the source model.
type DuplicateModelRequest struct {
	DisplayName string `json:"display_name,omitempty"` // Optional override for the copy's display name.
}
270
271func (s *Server) handleDuplicateModel(w http.ResponseWriter, r *http.Request, modelID string) {
272 // Get the source model (including API key)
273 source, err := s.db.GetModel(r.Context(), modelID)
274 if err != nil {
275 http.Error(w, fmt.Sprintf("Source model not found: %v", err), http.StatusNotFound)
276 return
277 }
278
279 // Parse optional overrides
280 var req DuplicateModelRequest
281 if r.Body != nil {
282 json.NewDecoder(r.Body).Decode(&req) // Ignore errors - all fields optional
283 }
284
285 // Generate new model ID
286 newModelID := "custom-" + uuid.New().String()[:8]
287
288 // Use provided display name or generate one
289 displayName := req.DisplayName
290 if displayName == "" {
291 displayName = source.DisplayName + " (copy)"
292 }
293
294 // Create the duplicate with the same API key
295 model, err := s.db.CreateModel(r.Context(), generated.CreateModelParams{
296 ModelID: newModelID,
297 DisplayName: displayName,
298 ProviderType: source.ProviderType,
299 Endpoint: source.Endpoint,
300 ApiKey: source.ApiKey, // Copy the API key!
301 ModelName: source.ModelName,
302 MaxTokens: source.MaxTokens,
303 Tags: "", // Don't copy tags
304 })
305 if err != nil {
306 http.Error(w, fmt.Sprintf("Failed to duplicate model: %v", err), http.StatusInternalServerError)
307 return
308 }
309
310 // Refresh the model manager's cache
311 if err := s.llmManager.RefreshCustomModels(); err != nil {
312 s.logger.Warn("Failed to refresh custom models cache", "error", err)
313 }
314
315 w.Header().Set("Content-Type", "application/json")
316 w.WriteHeader(http.StatusCreated)
317 json.NewEncoder(w).Encode(toModelAPI(*model))
318}
319
320func (s *Server) handleTestModel(w http.ResponseWriter, r *http.Request) {
321 if r.Method != http.MethodPost {
322 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
323 return
324 }
325
326 var req TestModelRequest
327 if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
328 http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
329 return
330 }
331
332 // If model_id is provided and api_key is empty, look up the stored key
333 if req.ModelID != "" && req.APIKey == "" {
334 model, err := s.db.GetModel(r.Context(), req.ModelID)
335 if err != nil {
336 http.Error(w, fmt.Sprintf("Model not found: %v", err), http.StatusNotFound)
337 return
338 }
339 req.APIKey = model.ApiKey
340 }
341
342 if req.ProviderType == "" || req.Endpoint == "" || req.APIKey == "" || req.ModelName == "" {
343 http.Error(w, "provider_type, endpoint, api_key, and model_name are required", http.StatusBadRequest)
344 return
345 }
346
347 // Create the appropriate service based on provider type
348 var service llm.Service
349 switch req.ProviderType {
350 case "anthropic":
351 service = &ant.Service{
352 APIKey: req.APIKey,
353 URL: req.Endpoint,
354 Model: req.ModelName,
355 }
356 case "openai":
357 service = &oai.Service{
358 APIKey: req.APIKey,
359 ModelURL: req.Endpoint,
360 Model: oai.Model{
361 ModelName: req.ModelName,
362 URL: req.Endpoint,
363 },
364 }
365 case "gemini":
366 service = &gem.Service{
367 APIKey: req.APIKey,
368 URL: req.Endpoint,
369 Model: req.ModelName,
370 }
371 case "openai-responses":
372 service = &oai.ResponsesService{
373 APIKey: req.APIKey,
374 Model: oai.Model{
375 ModelName: req.ModelName,
376 URL: req.Endpoint,
377 },
378 }
379 default:
380 http.Error(w, "Invalid provider_type", http.StatusBadRequest)
381 return
382 }
383
384 // Send a simple test request
385 ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
386 defer cancel()
387
388 request := &llm.Request{
389 Messages: []llm.Message{
390 {
391 Role: llm.MessageRoleUser,
392 Content: []llm.Content{
393 {Type: llm.ContentTypeText, Text: "Say 'test successful' in exactly two words."},
394 },
395 },
396 },
397 }
398
399 response, err := service.Do(ctx, request)
400 if err != nil {
401 w.Header().Set("Content-Type", "application/json")
402 json.NewEncoder(w).Encode(map[string]interface{}{
403 "success": false,
404 "message": fmt.Sprintf("Test failed: %v", err),
405 })
406 return
407 }
408
409 // Check if we got a response
410 if len(response.Content) == 0 || response.Content[0].Text == "" {
411 w.Header().Set("Content-Type", "application/json")
412 json.NewEncoder(w).Encode(map[string]interface{}{
413 "success": false,
414 "message": "Test failed: empty response from model",
415 })
416 return
417 }
418
419 w.Header().Set("Content-Type", "application/json")
420 json.NewEncoder(w).Encode(map[string]interface{}{
421 "success": true,
422 "message": fmt.Sprintf("Test successful! Response: %s", response.Content[0].Text),
423 })
424}