gem.go

  1package gem
  2
import (
	"cmp"
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"math/rand"
	"net/http"
	"strconv"
	"strings"
	"time"

	"shelley.exe.dev/llm"
	"shelley.exe.dev/llm/gem/gemini"
)
 17
const (
	// DefaultModel is the Gemini model used when Service.Model is empty.
	DefaultModel = "gemini-2.5-pro"
	// GeminiAPIKeyEnv is the name of the environment variable that holds
	// the Gemini API key.
	GeminiAPIKeyEnv = "GEMINI_API_KEY"
)
 22
// Service provides Gemini completions.
// It implements llm.Service.
// Fields should not be altered concurrently with calling any method on Service.
type Service struct {
	HTTPC  *http.Client // defaults to http.DefaultClient if nil
	URL    string       // Gemini API URL, uses the gemini package default if empty
	APIKey string       // must be non-empty
	Model  string       // defaults to DefaultModel if empty
}
 31
// Compile-time check that *Service satisfies llm.Service.
var _ llm.Service = (*Service)(nil)

// fromLLMRole maps Sketch's llm package message roles to the role strings
// the Gemini API expects ("model" for the assistant, "user" for the user).
var fromLLMRole = map[llm.MessageRole]string{
	llm.MessageRoleAssistant: "model",
	llm.MessageRoleUser:      "user",
}
 39
 40// convertToolSchemas converts Sketch's llm.Tool schemas to Gemini's schema format
 41func convertToolSchemas(tools []*llm.Tool) ([]gemini.FunctionDeclaration, error) {
 42	if len(tools) == 0 {
 43		return nil, nil
 44	}
 45
 46	var decls []gemini.FunctionDeclaration
 47	for _, tool := range tools {
 48		// Parse the schema from raw JSON
 49		var schemaJSON map[string]any
 50		if err := json.Unmarshal(tool.InputSchema, &schemaJSON); err != nil {
 51			return nil, fmt.Errorf("failed to unmarshal tool %s schema: %w", tool.Name, err)
 52		}
 53		decls = append(decls, gemini.FunctionDeclaration{
 54			Name:        tool.Name,
 55			Description: tool.Description,
 56			Parameters:  convertJSONSchemaToGeminiSchema(schemaJSON),
 57		})
 58	}
 59
 60	return decls, nil
 61}
 62
 63// convertJSONSchemaToGeminiSchema converts a JSON schema to Gemini's schema format
 64func convertJSONSchemaToGeminiSchema(schemaJSON map[string]any) gemini.Schema {
 65	schema := gemini.Schema{}
 66
 67	// Set the type based on the JSON schema type
 68	if typeVal, ok := schemaJSON["type"].(string); ok {
 69		switch typeVal {
 70		case "string":
 71			schema.Type = gemini.DataTypeSTRING
 72		case "number":
 73			schema.Type = gemini.DataTypeNUMBER
 74		case "integer":
 75			schema.Type = gemini.DataTypeINTEGER
 76		case "boolean":
 77			schema.Type = gemini.DataTypeBOOLEAN
 78		case "array":
 79			schema.Type = gemini.DataTypeARRAY
 80		case "object":
 81			schema.Type = gemini.DataTypeOBJECT
 82		default:
 83			schema.Type = gemini.DataTypeSTRING // Default to string for unknown types
 84		}
 85	}
 86
 87	// Set description if available
 88	if desc, ok := schemaJSON["description"].(string); ok {
 89		schema.Description = desc
 90	}
 91
 92	// Handle enum values
 93	if enumValues, ok := schemaJSON["enum"].([]any); ok {
 94		schema.Enum = make([]string, len(enumValues))
 95		for i, v := range enumValues {
 96			if strVal, ok := v.(string); ok {
 97				schema.Enum[i] = strVal
 98			} else {
 99				// Convert non-string values to string
100				valBytes, _ := json.Marshal(v)
101				schema.Enum[i] = string(valBytes)
102			}
103		}
104	}
105
106	// Handle object properties
107	if properties, ok := schemaJSON["properties"].(map[string]any); ok && schema.Type == gemini.DataTypeOBJECT {
108		schema.Properties = make(map[string]gemini.Schema)
109		for propName, propSchema := range properties {
110			if propSchemaMap, ok := propSchema.(map[string]any); ok {
111				schema.Properties[propName] = convertJSONSchemaToGeminiSchema(propSchemaMap)
112			}
113		}
114	}
115
116	// Handle required properties
117	if required, ok := schemaJSON["required"].([]any); ok {
118		schema.Required = make([]string, len(required))
119		for i, r := range required {
120			if strVal, ok := r.(string); ok {
121				schema.Required[i] = strVal
122			}
123		}
124	}
125
126	// Handle array items
127	if items, ok := schemaJSON["items"].(map[string]any); ok && schema.Type == gemini.DataTypeARRAY {
128		itemSchema := convertJSONSchemaToGeminiSchema(items)
129		schema.Items = &itemSchema
130	}
131
132	// Handle minimum/maximum items for arrays
133	if minItems, ok := schemaJSON["minItems"].(float64); ok {
134		schema.MinItems = fmt.Sprintf("%d", int(minItems))
135	}
136	if maxItems, ok := schemaJSON["maxItems"].(float64); ok {
137		schema.MaxItems = fmt.Sprintf("%d", int(maxItems))
138	}
139
140	return schema
141}
142
143// buildGeminiRequest converts Sketch's llm.Request to Gemini's request format
144func (s *Service) buildGeminiRequest(req *llm.Request) (*gemini.Request, error) {
145	gemReq := &gemini.Request{}
146
147	// Add system instruction if provided
148	if len(req.System) > 0 {
149		// Combine all system messages into a single system instruction
150		systemText := ""
151		for i, sys := range req.System {
152			if i > 0 && systemText != "" && sys.Text != "" {
153				systemText += "\n"
154			}
155			systemText += sys.Text
156		}
157
158		if systemText != "" {
159			gemReq.SystemInstruction = &gemini.Content{
160				Parts: []gemini.Part{{Text: systemText}},
161			}
162		}
163	}
164
165	// Convert messages to Gemini content format
166	for _, msg := range req.Messages {
167		// Set the role based on the message role
168		role, ok := fromLLMRole[msg.Role]
169		if !ok {
170			return nil, fmt.Errorf("unsupported message role: %v", msg.Role)
171		}
172
173		content := gemini.Content{
174			Role: role,
175		}
176
177		// Store tool usage information to correlate tool uses with responses
178		toolNameToID := make(map[string]string)
179
180		// First pass: collect tool use IDs for correlation
181		for _, c := range msg.Content {
182			if c.Type == llm.ContentTypeToolUse && c.ID != "" {
183				toolNameToID[c.ToolName] = c.ID
184			}
185		}
186
187		// Map each content item to Gemini's format
188		for _, c := range msg.Content {
189			switch c.Type {
190			case llm.ContentTypeText, llm.ContentTypeThinking, llm.ContentTypeRedactedThinking:
191				// Simple text content
192				content.Parts = append(content.Parts, gemini.Part{
193					Text: c.Text,
194				})
195			case llm.ContentTypeToolUse:
196				// Tool use becomes a function call
197				var args map[string]any
198				if err := json.Unmarshal(c.ToolInput, &args); err != nil {
199					return nil, fmt.Errorf("failed to unmarshal tool input: %w", err)
200				}
201
202				// Make sure we have a valid ID for this tool use
203				if c.ID == "" {
204					c.ID = fmt.Sprintf("gemini_tool_%s_%d", c.ToolName, time.Now().UnixNano())
205				}
206
207				// Save the ID for this tool name for future correlation
208				toolNameToID[c.ToolName] = c.ID
209
210				slog.DebugContext(context.Background(), "gemini_preparing_tool_use",
211					"tool_name", c.ToolName,
212					"tool_id", c.ID,
213					"input", string(c.ToolInput),
214					"thought_signature", c.Signature)
215
216				content.Parts = append(content.Parts, gemini.Part{
217					FunctionCall: &gemini.FunctionCall{
218						Name: c.ToolName,
219						Args: args,
220					},
221					// Gemini 3 requires thought signatures to be passed back for function calls
222					ThoughtSignature: c.Signature,
223				})
224			case llm.ContentTypeToolResult:
225				// Tool result becomes a function response
226				// Create a map for the response
227				response := map[string]any{
228					"error": c.ToolError,
229				}
230
231				// Handle tool results: Gemini only supports string results
232				// Combine all text content into a single string
233				var resultText string
234				if len(c.ToolResult) > 0 {
235					// Collect all text from content objects
236					texts := make([]string, 0, len(c.ToolResult))
237					for _, result := range c.ToolResult {
238						if result.Text != "" {
239							texts = append(texts, result.Text)
240						}
241					}
242					resultText = strings.Join(texts, "\n")
243				}
244				response["result"] = resultText
245
246				// Determine the function name to use - this is critical
247				funcName := ""
248
249				// First try to find the function name from a stored toolUseID if we have one
250				if c.ToolUseID != "" {
251					// Try to derive the tool name from the previous tools we've seen
252					for name, id := range toolNameToID {
253						if id == c.ToolUseID {
254							funcName = name
255							break
256						}
257					}
258				}
259
260				// Fallback options if we couldn't find the tool name
261				if funcName == "" {
262					// Try the tool name directly
263					if c.ToolName != "" {
264						funcName = c.ToolName
265					} else {
266						// Last resort fallback
267						funcName = "default_tool"
268					}
269				}
270
271				slog.DebugContext(context.Background(), "gemini_preparing_tool_result",
272					"tool_use_id", c.ToolUseID,
273					"mapped_func_name", funcName,
274					"result_count", len(c.ToolResult))
275
276				content.Parts = append(content.Parts, gemini.Part{
277					FunctionResponse: &gemini.FunctionResponse{
278						Name:     funcName,
279						Response: response,
280					},
281				})
282			}
283		}
284
285		gemReq.Contents = append(gemReq.Contents, content)
286	}
287
288	// Handle tools/functions
289	if len(req.Tools) > 0 {
290		// Convert tool schemas
291		decls, err := convertToolSchemas(req.Tools)
292		if err != nil {
293			return nil, fmt.Errorf("failed to convert tool schemas: %w", err)
294		}
295		if len(decls) > 0 {
296			gemReq.Tools = []gemini.Tool{{FunctionDeclarations: decls}}
297		}
298	}
299
300	return gemReq, nil
301}
302
303// convertGeminiResponsesToContent converts a Gemini response to llm.Content
304func convertGeminiResponseToContent(res *gemini.Response) []llm.Content {
305	if res == nil || len(res.Candidates) == 0 || len(res.Candidates[0].Content.Parts) == 0 {
306		return []llm.Content{{
307			Type: llm.ContentTypeText,
308			Text: "",
309		}}
310	}
311
312	var contents []llm.Content
313
314	// Process each part in the first candidate's content
315	for i, part := range res.Candidates[0].Content.Parts {
316		// Log the part type for debugging
317		slog.DebugContext(context.Background(), "processing_gemini_part",
318			"index", i,
319			"has_text", part.Text != "",
320			"has_function_call", part.FunctionCall != nil,
321			"has_function_response", part.FunctionResponse != nil)
322
323		if part.Text != "" {
324			// Simple text response
325			contents = append(contents, llm.Content{
326				Type:      llm.ContentTypeText,
327				Text:      part.Text,
328				Signature: part.ThoughtSignature, // Capture thought signature for text parts too
329			})
330		} else if part.FunctionCall != nil {
331			// Function call (tool use)
332			args, err := json.Marshal(part.FunctionCall.Args)
333			if err != nil {
334				// If we can't marshal, use empty args
335				slog.DebugContext(context.Background(), "gemini_failed_to_marshal_args",
336					"tool_name", part.FunctionCall.Name,
337					"args", string(args),
338					"err", err.Error(),
339				)
340				args = []byte("{}")
341			}
342
343			// Generate a unique ID for this tool use that includes the function name
344			// to make it easier to correlate with responses
345			toolID := fmt.Sprintf("gemini_tool_%s_%d", part.FunctionCall.Name, time.Now().UnixNano())
346
347			contents = append(contents, llm.Content{
348				ID:        toolID,
349				Type:      llm.ContentTypeToolUse,
350				ToolName:  part.FunctionCall.Name,
351				ToolInput: json.RawMessage(args),
352				// Capture thought signature - required for Gemini 3 function calling
353				Signature: part.ThoughtSignature,
354			})
355
356			slog.DebugContext(context.Background(), "gemini_tool_call",
357				"tool_id", toolID,
358				"tool_name", part.FunctionCall.Name,
359				"args", string(args),
360				"thought_signature", part.ThoughtSignature)
361		} else if part.FunctionResponse != nil {
362			// We shouldn't normally get function responses from the model, but just in case
363			respData, _ := json.Marshal(part.FunctionResponse.Response)
364			slog.DebugContext(context.Background(), "unexpected_function_response",
365				"name", part.FunctionResponse.Name,
366				"response", string(respData))
367		}
368	}
369
370	// If no content was added, add an empty text content
371	if len(contents) == 0 {
372		slog.DebugContext(context.Background(), "empty_gemini_response", "adding_empty_text", true)
373		contents = append(contents, llm.Content{
374			Type: llm.ContentTypeText,
375			Text: "",
376		})
377	}
378
379	return contents
380}
381
382// Gemini doesn't provide usage info directly, so we need to estimate it
383// ensureToolIDs makes sure all tool uses have proper IDs
384func ensureToolIDs(contents []llm.Content) {
385	for i, content := range contents {
386		if content.Type == llm.ContentTypeToolUse && content.ID == "" {
387			// Generate a stable ID using the tool name and timestamp
388			contents[i].ID = fmt.Sprintf("gemini_tool_%s_%d", content.ToolName, time.Now().UnixNano())
389			slog.DebugContext(context.Background(), "assigned_missing_tool_id",
390				"tool_name", content.ToolName,
391				"new_id", contents[i].ID)
392		}
393	}
394}
395
396func calculateUsage(req *gemini.Request, res *gemini.Response) llm.Usage {
397	// Very rough estimation of token counts
398	var inputTokens uint64
399	var outputTokens uint64
400
401	// Count system tokens
402	if req.SystemInstruction != nil {
403		for _, part := range req.SystemInstruction.Parts {
404			if part.Text != "" {
405				// Very rough estimation: 1 token per 4 characters
406				inputTokens += uint64(len(part.Text)) / 4
407			}
408		}
409	}
410
411	// Count input tokens
412	for _, content := range req.Contents {
413		for _, part := range content.Parts {
414			if part.Text != "" {
415				inputTokens += uint64(len(part.Text)) / 4
416			} else if part.FunctionCall != nil {
417				// Estimate function call tokens
418				argBytes, _ := json.Marshal(part.FunctionCall.Args)
419				inputTokens += uint64(len(part.FunctionCall.Name)+len(argBytes)) / 4
420			} else if part.FunctionResponse != nil {
421				// Estimate function response tokens
422				resBytes, _ := json.Marshal(part.FunctionResponse.Response)
423				inputTokens += uint64(len(part.FunctionResponse.Name)+len(resBytes)) / 4
424			}
425		}
426	}
427
428	// Count output tokens
429	if res != nil && len(res.Candidates) > 0 {
430		for _, part := range res.Candidates[0].Content.Parts {
431			if part.Text != "" {
432				outputTokens += uint64(len(part.Text)) / 4
433			} else if part.FunctionCall != nil {
434				// Estimate function call tokens
435				argBytes, _ := json.Marshal(part.FunctionCall.Args)
436				outputTokens += uint64(len(part.FunctionCall.Name)+len(argBytes)) / 4
437			}
438		}
439	}
440
441	return llm.Usage{
442		InputTokens:  inputTokens,
443		OutputTokens: outputTokens,
444	}
445}
446
447// TokenContextWindow returns the maximum token context window size for this service
448func (s *Service) TokenContextWindow() int {
449	model := s.Model
450	if model == "" {
451		model = DefaultModel
452	}
453
454	// Gemini models generally have large context windows
455	switch model {
456	case "gemini-3-pro-preview", "gemini-3-flash-preview":
457		return 1000000 // 1M tokens for Gemini 3
458	case "gemini-2.5-pro", "gemini-2.5-flash":
459		return 1000000 // 1M tokens for Gemini 2.5
460	case "gemini-2.0-flash-exp", "gemini-2.0-flash":
461		return 1000000 // 1M tokens for Gemini 2.0 Flash
462	case "gemini-1.5-pro", "gemini-1.5-pro-latest":
463		return 2000000 // 2M tokens for Gemini 1.5 Pro
464	case "gemini-1.5-flash", "gemini-1.5-flash-latest":
465		return 1000000 // 1M tokens for Gemini 1.5 Flash
466	default:
467		// Default for unknown models
468		return 1000000
469	}
470}
471
472// MaxImageDimension returns the maximum allowed image dimension.
473// TODO: determine actual Gemini image dimension limits
474func (s *Service) MaxImageDimension() int {
475	return 0 // No known limit
476}
477
478// Do sends a request to Gemini.
479func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
480	// Log the incoming request for debugging
481	slog.DebugContext(ctx, "gemini_request",
482		"message_count", len(ir.Messages),
483		"tool_count", len(ir.Tools),
484		"system_count", len(ir.System))
485
486	// Log tool-related information if any tools are present
487	if len(ir.Tools) > 0 {
488		var toolNames []string
489		for _, tool := range ir.Tools {
490			toolNames = append(toolNames, tool.Name)
491		}
492		slog.DebugContext(ctx, "gemini_tools", "tools", toolNames)
493	}
494
495	// Log details about the messages being sent
496	for i, msg := range ir.Messages {
497		contentTypes := make([]string, len(msg.Content))
498		for j, c := range msg.Content {
499			contentTypes[j] = c.Type.String()
500
501			// Log tool-related content with more details
502			if c.Type == llm.ContentTypeToolUse {
503				slog.DebugContext(ctx, "gemini_tool_use",
504					"message_idx", i,
505					"content_idx", j,
506					"tool_name", c.ToolName,
507					"tool_input", string(c.ToolInput))
508			} else if c.Type == llm.ContentTypeToolResult {
509				slog.DebugContext(ctx, "gemini_tool_result",
510					"message_idx", i,
511					"content_idx", j,
512					"tool_use_id", c.ToolUseID,
513					"tool_error", c.ToolError,
514					"result_count", len(c.ToolResult))
515			}
516		}
517		slog.DebugContext(ctx, "gemini_message",
518			"idx", i,
519			"role", msg.Role.String(),
520			"content_types", contentTypes)
521	}
522	// Build the Gemini request
523	gemReq, err := s.buildGeminiRequest(ir)
524	if err != nil {
525		return nil, fmt.Errorf("failed to build Gemini request: %w", err)
526	}
527
528	// Log the structured Gemini request for debugging
529	if reqJSON, err := json.MarshalIndent(gemReq, "", "  "); err == nil {
530		slog.DebugContext(ctx, "gemini_request_json", "request", string(reqJSON))
531	}
532
533	// Create a Gemini model instance
534	model := gemini.Model{
535		Model:    "models/" + cmp.Or(s.Model, DefaultModel),
536		Endpoint: s.URL,
537		APIKey:   s.APIKey,
538		HTTPC:    cmp.Or(s.HTTPC, http.DefaultClient),
539	}
540
541	// Send the request to Gemini with retry logic
542	startTime := time.Now()
543	endTime := startTime // Initialize endTime
544	var gemRes *gemini.Response
545
546	// Retry mechanism for handling server errors and rate limiting
547	backoff := []time.Duration{1 * time.Second, 3 * time.Second, 5 * time.Second, 10 * time.Second}
548	for attempts := 0; attempts <= len(backoff); attempts++ {
549		gemApiErr := error(nil)
550		gemRes, gemApiErr = model.GenerateContent(ctx, gemReq)
551		endTime = time.Now()
552
553		if gemApiErr == nil {
554			// Successful response
555			// Log the structured Gemini response
556			if resJSON, err := json.MarshalIndent(gemRes, "", "  "); err == nil {
557				slog.DebugContext(ctx, "gemini_response_json", "response", string(resJSON))
558			}
559			break
560		}
561
562		if attempts == len(backoff) {
563			// We've exhausted all retry attempts
564			return nil, fmt.Errorf("gemini: API error after %d attempts: %w", attempts, gemApiErr)
565		}
566
567		// Check if the error is retryable (e.g., server error or rate limiting)
568		if strings.Contains(gemApiErr.Error(), "429") || strings.Contains(gemApiErr.Error(), "5") {
569			// Rate limited or server error - wait and retry
570			random := time.Duration(rand.Int63n(int64(time.Second)))
571			sleep := backoff[attempts] + random
572			slog.WarnContext(ctx, "gemini_request_retry", "error", gemApiErr.Error(), "attempt", attempts+1, "sleep", sleep)
573			time.Sleep(sleep)
574			continue
575		}
576
577		// Non-retryable error
578		return nil, fmt.Errorf("gemini: API error: %w", gemApiErr)
579	}
580
581	content := convertGeminiResponseToContent(gemRes)
582
583	ensureToolIDs(content)
584
585	usage := calculateUsage(gemReq, gemRes)
586	usage.CostUSD = llm.CostUSDFromResponse(gemRes.Header())
587
588	stopReason := llm.StopReasonEndTurn
589	for _, part := range content {
590		if part.Type == llm.ContentTypeToolUse {
591			stopReason = llm.StopReasonToolUse
592			slog.DebugContext(ctx, "gemini_tool_use_detected",
593				"setting_stop_reason", "llm.StopReasonToolUse",
594				"tool_name", part.ToolName)
595			break
596		}
597	}
598
599	return &llm.Response{
600		Role:       llm.MessageRoleAssistant,
601		Model:      s.Model,
602		Content:    content,
603		StopReason: stopReason,
604		Usage:      usage,
605		StartTime:  &startTime,
606		EndTime:    &endTime,
607	}, nil
608}