custom_models.go

  1package server
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"net/http"
  8	"strings"
  9	"time"
 10
 11	"github.com/google/uuid"
 12	"shelley.exe.dev/db/generated"
 13	"shelley.exe.dev/llm"
 14	"shelley.exe.dev/llm/ant"
 15	"shelley.exe.dev/llm/gem"
 16	"shelley.exe.dev/llm/oai"
 17)
 18
// ModelAPI is the API representation of a model as served by the
// /api/custom-models endpoints. Its fields mirror the database row
// (generated.Model); note that it includes the raw API key, so responses
// containing it expose the stored credential to the client.
type ModelAPI struct {
	ModelID      string `json:"model_id"`      // unique ID, e.g. "custom-1a2b3c4d"
	DisplayName  string `json:"display_name"`  // human-readable label shown in the UI
	ProviderType string `json:"provider_type"` // "anthropic", "openai", "openai-responses", or "gemini"
	Endpoint     string `json:"endpoint"`      // base URL of the provider API
	APIKey       string `json:"api_key"`       // credential used to call the provider
	ModelName    string `json:"model_name"`    // provider-side model identifier
	MaxTokens    int64  `json:"max_tokens"`    // token limit; defaults to 200000 on create/update
	Tags         string `json:"tags"`          // Comma-separated tags (e.g., "slug" for slug generation)
}
 30
// CreateModelRequest is the request body for creating a model.
// All fields except MaxTokens and Tags are required; MaxTokens falls back
// to 200000 when omitted or non-positive.
type CreateModelRequest struct {
	DisplayName  string `json:"display_name"`
	ProviderType string `json:"provider_type"` // must be "anthropic", "openai", "openai-responses", or "gemini"
	Endpoint     string `json:"endpoint"`
	APIKey       string `json:"api_key"`
	ModelName    string `json:"model_name"`
	MaxTokens    int64  `json:"max_tokens"`
	Tags         string `json:"tags"` // Comma-separated tags
}
 41
// UpdateModelRequest is the request body for updating a model.
// The update is a full replacement of the stored row, with one exception:
// an empty APIKey means "keep the existing key".
type UpdateModelRequest struct {
	DisplayName  string `json:"display_name"`
	ProviderType string `json:"provider_type"`
	Endpoint     string `json:"endpoint"`
	APIKey       string `json:"api_key"` // Empty string means keep existing
	ModelName    string `json:"model_name"`
	MaxTokens    int64  `json:"max_tokens"`
	Tags         string `json:"tags"` // Comma-separated tags
}
 52
// TestModelRequest is the request body for testing a model's connectivity.
// When ModelID is set and APIKey is empty, the handler substitutes the key
// stored for that model, so clients can test saved models without re-sending
// the credential.
type TestModelRequest struct {
	ModelID      string `json:"model_id,omitempty"` // If provided, use stored API key
	ProviderType string `json:"provider_type"`
	Endpoint     string `json:"endpoint"`
	APIKey       string `json:"api_key"`
	ModelName    string `json:"model_name"`
}
 61
 62func toModelAPI(m generated.Model) ModelAPI {
 63	return ModelAPI{
 64		ModelID:      m.ModelID,
 65		DisplayName:  m.DisplayName,
 66		ProviderType: m.ProviderType,
 67		Endpoint:     m.Endpoint,
 68		APIKey:       m.ApiKey,
 69		ModelName:    m.ModelName,
 70		MaxTokens:    m.MaxTokens,
 71		Tags:         m.Tags,
 72	}
 73}
 74
 75func (s *Server) handleCustomModels(w http.ResponseWriter, r *http.Request) {
 76	switch r.Method {
 77	case http.MethodGet:
 78		s.handleListModels(w, r)
 79	case http.MethodPost:
 80		s.handleCreateModel(w, r)
 81	default:
 82		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
 83	}
 84}
 85
 86func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
 87	models, err := s.db.GetModels(r.Context())
 88	if err != nil {
 89		http.Error(w, fmt.Sprintf("Failed to get models: %v", err), http.StatusInternalServerError)
 90		return
 91	}
 92
 93	apiModels := make([]ModelAPI, len(models))
 94	for i, m := range models {
 95		apiModels[i] = toModelAPI(m)
 96	}
 97
 98	w.Header().Set("Content-Type", "application/json")
 99	json.NewEncoder(w).Encode(apiModels)
100}
101
102func (s *Server) handleCreateModel(w http.ResponseWriter, r *http.Request) {
103	var req CreateModelRequest
104	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
105		http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
106		return
107	}
108
109	// Validate required fields
110	if req.DisplayName == "" || req.ProviderType == "" || req.Endpoint == "" || req.APIKey == "" || req.ModelName == "" {
111		http.Error(w, "display_name, provider_type, endpoint, api_key, and model_name are required", http.StatusBadRequest)
112		return
113	}
114
115	// Validate provider type
116	if req.ProviderType != "anthropic" && req.ProviderType != "openai" && req.ProviderType != "openai-responses" && req.ProviderType != "gemini" {
117		http.Error(w, "provider_type must be 'anthropic', 'openai', 'openai-responses', or 'gemini'", http.StatusBadRequest)
118		return
119	}
120
121	// Generate model ID
122	modelID := "custom-" + uuid.New().String()[:8]
123
124	// Default max tokens
125	if req.MaxTokens <= 0 {
126		req.MaxTokens = 200000
127	}
128
129	model, err := s.db.CreateModel(r.Context(), generated.CreateModelParams{
130		ModelID:      modelID,
131		DisplayName:  req.DisplayName,
132		ProviderType: req.ProviderType,
133		Endpoint:     req.Endpoint,
134		ApiKey:       req.APIKey,
135		ModelName:    req.ModelName,
136		MaxTokens:    req.MaxTokens,
137		Tags:         req.Tags,
138	})
139	if err != nil {
140		http.Error(w, fmt.Sprintf("Failed to create model: %v", err), http.StatusInternalServerError)
141		return
142	}
143
144	// Refresh the model manager's cache
145	if err := s.llmManager.RefreshCustomModels(); err != nil {
146		s.logger.Warn("Failed to refresh custom models cache", "error", err)
147	}
148
149	w.Header().Set("Content-Type", "application/json")
150	w.WriteHeader(http.StatusCreated)
151	json.NewEncoder(w).Encode(toModelAPI(*model))
152}
153
154func (s *Server) handleCustomModel(w http.ResponseWriter, r *http.Request) {
155	// Extract model ID from URL path: /api/custom-models/{id} or /api/custom-models/{id}/duplicate
156	path := strings.TrimPrefix(r.URL.Path, "/api/custom-models/")
157	if path == "" {
158		http.Error(w, "Invalid model ID", http.StatusBadRequest)
159		return
160	}
161
162	// Check for /duplicate suffix
163	if strings.HasSuffix(path, "/duplicate") {
164		modelID := strings.TrimSuffix(path, "/duplicate")
165		if r.Method == http.MethodPost {
166			s.handleDuplicateModel(w, r, modelID)
167		} else {
168			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
169		}
170		return
171	}
172
173	if strings.Contains(path, "/") {
174		http.Error(w, "Invalid model ID", http.StatusBadRequest)
175		return
176	}
177	modelID := path
178
179	switch r.Method {
180	case http.MethodGet:
181		s.handleGetModel(w, r, modelID)
182	case http.MethodPut:
183		s.handleUpdateModel(w, r, modelID)
184	case http.MethodDelete:
185		s.handleDeleteModel(w, r, modelID)
186	default:
187		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
188	}
189}
190
191func (s *Server) handleGetModel(w http.ResponseWriter, r *http.Request, modelID string) {
192	model, err := s.db.GetModel(r.Context(), modelID)
193	if err != nil {
194		http.Error(w, fmt.Sprintf("Failed to get model: %v", err), http.StatusNotFound)
195		return
196	}
197
198	w.Header().Set("Content-Type", "application/json")
199	json.NewEncoder(w).Encode(toModelAPI(*model))
200}
201
202func (s *Server) handleUpdateModel(w http.ResponseWriter, r *http.Request, modelID string) {
203	// First, get the existing model to get the current API key if not provided
204	existing, err := s.db.GetModel(r.Context(), modelID)
205	if err != nil {
206		http.Error(w, fmt.Sprintf("Model not found: %v", err), http.StatusNotFound)
207		return
208	}
209
210	var req UpdateModelRequest
211	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
212		http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
213		return
214	}
215
216	// Use existing API key if not provided
217	apiKey := req.APIKey
218	if apiKey == "" {
219		apiKey = existing.ApiKey
220	}
221
222	// Default max tokens
223	if req.MaxTokens <= 0 {
224		req.MaxTokens = 200000
225	}
226
227	model, err := s.db.UpdateModel(r.Context(), generated.UpdateModelParams{
228		DisplayName:  req.DisplayName,
229		ProviderType: req.ProviderType,
230		Endpoint:     req.Endpoint,
231		ApiKey:       apiKey,
232		ModelName:    req.ModelName,
233		MaxTokens:    req.MaxTokens,
234		Tags:         req.Tags,
235		ModelID:      modelID,
236	})
237	if err != nil {
238		http.Error(w, fmt.Sprintf("Failed to update model: %v", err), http.StatusInternalServerError)
239		return
240	}
241
242	// Refresh the model manager's cache
243	if err := s.llmManager.RefreshCustomModels(); err != nil {
244		s.logger.Warn("Failed to refresh custom models cache", "error", err)
245	}
246
247	w.Header().Set("Content-Type", "application/json")
248	json.NewEncoder(w).Encode(toModelAPI(*model))
249}
250
251func (s *Server) handleDeleteModel(w http.ResponseWriter, r *http.Request, modelID string) {
252	err := s.db.DeleteModel(r.Context(), modelID)
253	if err != nil {
254		http.Error(w, fmt.Sprintf("Failed to delete model: %v", err), http.StatusInternalServerError)
255		return
256	}
257
258	// Refresh the model manager's cache
259	if err := s.llmManager.RefreshCustomModels(); err != nil {
260		s.logger.Warn("Failed to refresh custom models cache", "error", err)
261	}
262
263	w.WriteHeader(http.StatusNoContent)
264}
265
// DuplicateModelRequest allows overriding fields when duplicating.
// All fields are optional; when DisplayName is empty the handler derives
// a name from the source model (" (copy)" suffix).
type DuplicateModelRequest struct {
	DisplayName string `json:"display_name,omitempty"`
}
270
271func (s *Server) handleDuplicateModel(w http.ResponseWriter, r *http.Request, modelID string) {
272	// Get the source model (including API key)
273	source, err := s.db.GetModel(r.Context(), modelID)
274	if err != nil {
275		http.Error(w, fmt.Sprintf("Source model not found: %v", err), http.StatusNotFound)
276		return
277	}
278
279	// Parse optional overrides
280	var req DuplicateModelRequest
281	if r.Body != nil {
282		json.NewDecoder(r.Body).Decode(&req) // Ignore errors - all fields optional
283	}
284
285	// Generate new model ID
286	newModelID := "custom-" + uuid.New().String()[:8]
287
288	// Use provided display name or generate one
289	displayName := req.DisplayName
290	if displayName == "" {
291		displayName = source.DisplayName + " (copy)"
292	}
293
294	// Create the duplicate with the same API key
295	model, err := s.db.CreateModel(r.Context(), generated.CreateModelParams{
296		ModelID:      newModelID,
297		DisplayName:  displayName,
298		ProviderType: source.ProviderType,
299		Endpoint:     source.Endpoint,
300		ApiKey:       source.ApiKey, // Copy the API key!
301		ModelName:    source.ModelName,
302		MaxTokens:    source.MaxTokens,
303		Tags:         "", // Don't copy tags
304	})
305	if err != nil {
306		http.Error(w, fmt.Sprintf("Failed to duplicate model: %v", err), http.StatusInternalServerError)
307		return
308	}
309
310	// Refresh the model manager's cache
311	if err := s.llmManager.RefreshCustomModels(); err != nil {
312		s.logger.Warn("Failed to refresh custom models cache", "error", err)
313	}
314
315	w.Header().Set("Content-Type", "application/json")
316	w.WriteHeader(http.StatusCreated)
317	json.NewEncoder(w).Encode(toModelAPI(*model))
318}
319
320func (s *Server) handleTestModel(w http.ResponseWriter, r *http.Request) {
321	if r.Method != http.MethodPost {
322		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
323		return
324	}
325
326	var req TestModelRequest
327	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
328		http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
329		return
330	}
331
332	// If model_id is provided and api_key is empty, look up the stored key
333	if req.ModelID != "" && req.APIKey == "" {
334		model, err := s.db.GetModel(r.Context(), req.ModelID)
335		if err != nil {
336			http.Error(w, fmt.Sprintf("Model not found: %v", err), http.StatusNotFound)
337			return
338		}
339		req.APIKey = model.ApiKey
340	}
341
342	if req.ProviderType == "" || req.Endpoint == "" || req.APIKey == "" || req.ModelName == "" {
343		http.Error(w, "provider_type, endpoint, api_key, and model_name are required", http.StatusBadRequest)
344		return
345	}
346
347	// Create the appropriate service based on provider type
348	var service llm.Service
349	switch req.ProviderType {
350	case "anthropic":
351		service = &ant.Service{
352			APIKey: req.APIKey,
353			URL:    req.Endpoint,
354			Model:  req.ModelName,
355		}
356	case "openai":
357		service = &oai.Service{
358			APIKey:   req.APIKey,
359			ModelURL: req.Endpoint,
360			Model: oai.Model{
361				ModelName: req.ModelName,
362				URL:       req.Endpoint,
363			},
364		}
365	case "gemini":
366		service = &gem.Service{
367			APIKey: req.APIKey,
368			URL:    req.Endpoint,
369			Model:  req.ModelName,
370		}
371	case "openai-responses":
372		service = &oai.ResponsesService{
373			APIKey: req.APIKey,
374			Model: oai.Model{
375				ModelName: req.ModelName,
376				URL:       req.Endpoint,
377			},
378		}
379	default:
380		http.Error(w, "Invalid provider_type", http.StatusBadRequest)
381		return
382	}
383
384	// Send a simple test request
385	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
386	defer cancel()
387
388	request := &llm.Request{
389		Messages: []llm.Message{
390			{
391				Role: llm.MessageRoleUser,
392				Content: []llm.Content{
393					{Type: llm.ContentTypeText, Text: "Say 'test successful' in exactly two words."},
394				},
395			},
396		},
397	}
398
399	response, err := service.Do(ctx, request)
400	if err != nil {
401		w.Header().Set("Content-Type", "application/json")
402		json.NewEncoder(w).Encode(map[string]interface{}{
403			"success": false,
404			"message": fmt.Sprintf("Test failed: %v", err),
405		})
406		return
407	}
408
409	// Check if we got a response
410	if len(response.Content) == 0 || response.Content[0].Text == "" {
411		w.Header().Set("Content-Type", "application/json")
412		json.NewEncoder(w).Encode(map[string]interface{}{
413			"success": false,
414			"message": "Test failed: empty response from model",
415		})
416		return
417	}
418
419	w.Header().Set("Content-Type", "application/json")
420	json.NewEncoder(w).Encode(map[string]interface{}{
421		"success": true,
422		"message": fmt.Sprintf("Test successful! Response: %s", response.Content[0].Text),
423	})
424}