google.go

package google

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"maps"
	"net/http"
	"strings"

	"github.com/charmbracelet/fantasy/ai"
	"google.golang.org/genai"
)

type provider struct {
	options options
}

type options struct {
	apiKey  string
	name    string
	headers map[string]string
	client  *http.Client
}

// Option configures the provider created by New.
type Option = func(*options)

// New returns a Google (Gemini) provider configured by the given options.
// The provider name defaults to "google".
func New(opts ...Option) ai.Provider {
	options := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	if options.name == "" {
		options.name = "google"
	}

	return &provider{
		options: options,
	}
}

// WithAPIKey sets the Gemini API key.
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

// WithName overrides the provider name used to build the model's provider
// identifier; it defaults to "google".
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

// WithHeaders merges additional request headers into the provider options.
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

// WithHTTPClient sets the HTTP client used by the underlying genai client.
func WithHTTPClient(client *http.Client) Option {
	return func(o *options) {
		o.client = client
	}
}
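
// Example usage (a hedged sketch; the model ID and the error handling below
// are illustrative, not part of this package's contract):
//
//	p := New(WithAPIKey("YOUR_API_KEY"))
//	model, err := p.LanguageModel("gemini-2.0-flash")
//	if err != nil {
//		// handle error
//	}
//	_ = model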

type languageModel struct {
	provider        string
	modelID         string
	client          *genai.Client
	providerOptions options
}

// LanguageModel implements ai.Provider.
func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	cc := &genai.ClientConfig{
		APIKey:     g.options.apiKey,
		Backend:    genai.BackendGeminiAPI,
		HTTPClient: g.options.client,
	}
	client, err := genai.NewClient(context.Background(), cc)
	if err != nil {
		return nil, err
	}
	return &languageModel{
		modelID:         modelID,
		provider:        fmt.Sprintf("%s.generative-ai", g.options.name),
		providerOptions: g.options,
		client:          client,
	}, nil
}

func (g *languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
	config := &genai.GenerateContentConfig{}
	providerOptions := &providerOptions{}
	if v, ok := call.ProviderOptions["google"]; ok {
		err := ai.ParseOptions(v, providerOptions)
		if err != nil {
			return nil, nil, nil, err
		}
	}

	systemInstructions, content, warnings := toGooglePrompt(call.Prompt)

	if providerOptions.ThinkingConfig != nil &&
		providerOptions.ThinkingConfig.IncludeThoughts != nil &&
		*providerOptions.ThinkingConfig.IncludeThoughts &&
		!strings.HasPrefix(g.provider, "google.vertex.") {
		warnings = append(warnings, ai.CallWarning{
			Type: ai.CallWarningTypeOther,
			Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
				"and might not be supported or could behave unexpectedly with the current Google provider " +
				fmt.Sprintf("(%s)", g.provider),
		})
	}

	isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")

	// Gemma models do not accept a system role, so fold the system
	// instructions into the first user message instead.
	if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
		if len(content) > 0 && content[0].Role == genai.RoleUser {
			systemParts := []string{}
			for _, sp := range systemInstructions.Parts {
				systemParts = append(systemParts, sp.Text)
			}
			systemMsg := strings.Join(systemParts, "\n")
			content[0].Parts = append([]*genai.Part{
				{
					Text: systemMsg + "\n\n",
				},
			}, content[0].Parts...)
			systemInstructions = nil
		}
	}
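
	// Informal sketch of the fold above: a prompt of
	//
	//	system: "You are terse."
	//	user:   "Hi"
	//
	// reaches a Gemma model as a single user message whose parts are
	// "You are terse.\n\n" followed by "Hi", with no separate system
	// instruction set on the config.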

	config.SystemInstruction = systemInstructions

	if call.MaxOutputTokens != nil {
		config.MaxOutputTokens = int32(*call.MaxOutputTokens)
	}

	if call.Temperature != nil {
		tmp := float32(*call.Temperature)
		config.Temperature = &tmp
	}
	if call.TopK != nil {
		tmp := float32(*call.TopK)
		config.TopK = &tmp
	}
	if call.TopP != nil {
		tmp := float32(*call.TopP)
		config.TopP = &tmp
	}
	if call.FrequencyPenalty != nil {
		tmp := float32(*call.FrequencyPenalty)
		config.FrequencyPenalty = &tmp
	}
	if call.PresencePenalty != nil {
		tmp := float32(*call.PresencePenalty)
		config.PresencePenalty = &tmp
	}

	if providerOptions.ThinkingConfig != nil {
		config.ThinkingConfig = &genai.ThinkingConfig{}
		if providerOptions.ThinkingConfig.IncludeThoughts != nil {
			config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
		}
		if providerOptions.ThinkingConfig.ThinkingBudget != nil {
			tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget)
			config.ThinkingConfig.ThinkingBudget = &tmp
		}
	}
	for _, safetySetting := range providerOptions.SafetySettings {
		config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
			Category:  genai.HarmCategory(safetySetting.Category),
			Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
		})
	}
	if providerOptions.CachedContent != "" {
		config.CachedContent = providerOptions.CachedContent
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
		config.ToolConfig = toolChoice
		config.Tools = append(config.Tools, &genai.Tool{
			FunctionDeclarations: tools,
		})
		warnings = append(warnings, toolWarnings...)
	}

	return config, content, warnings, nil
}
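
// A hedged sketch of supplying Google-specific options on a call. The real
// option shape is defined by the providerOptions struct elsewhere in this
// package; the map form and the key names below are assumptions for
// illustration only:
//
//	call := ai.Call{
//		Prompt: prompt,
//		ProviderOptions: map[string]any{ // assumed shape
//			"google": map[string]any{
//				"thinkingConfig": map[string]any{ // assumed key names
//					"includeThoughts": true,
//					"thinkingBudget":  2048,
//				},
//			},
//		},
//	}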

func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) {
	var systemInstructions *genai.Content
	var content []*genai.Content
	var warnings []ai.CallWarning

	finishedSystemBlock := false
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			if finishedSystemBlock {
				// Only the first system message is used; any later system messages are skipped.
				// TODO: decide whether this should be an error instead.
				continue
			}
			finishedSystemBlock = true

			var systemMessages []string
			for _, part := range msg.Content {
				text, ok := ai.AsMessagePart[ai.TextPart](part)
				if !ok || text.Text == "" {
					continue
				}
				systemMessages = append(systemMessages, text.Text)
			}
			if len(systemMessages) > 0 {
				systemInstructions = &genai.Content{
					Parts: []*genai.Part{
						{
							Text: strings.Join(systemMessages, "\n"),
						},
					},
				}
			}
		case ai.MessageRoleUser:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeFile:
					file, ok := ai.AsMessagePart[ai.FilePart](part)
					if !ok {
						continue
					}
					// base64.Encode writes into a destination buffer that must
					// be pre-sized with EncodedLen.
					encoded := make([]byte, base64.StdEncoding.EncodedLen(len(file.Data)))
					base64.StdEncoding.Encode(encoded, file.Data)
					parts = append(parts, &genai.Part{
						InlineData: &genai.Blob{
							Data:     encoded,
							MIMEType: file.MediaType,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		case ai.MessageRoleAssistant:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeToolCall:
					toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
					if !ok {
						continue
					}

					var result map[string]any
					err := json.Unmarshal([]byte(toolCall.Input), &result)
					if err != nil {
						continue
					}
					parts = append(parts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							ID:   toolCall.ToolCallID,
							Name: toolCall.ToolName,
							Args: result,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleModel,
					Parts: parts,
				})
			}
		case ai.MessageRoleTool:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeToolResult:
					result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
					if !ok {
						continue
					}
					// Look up the originating tool call so the response can
					// carry the function name Gemini expects.
					var toolCall ai.ToolCallPart
					for _, m := range prompt {
						if m.Role == ai.MessageRoleAssistant {
							for _, assistantPart := range m.Content {
								tc, ok := ai.AsMessagePart[ai.ToolCallPart](assistantPart)
								if !ok {
									continue
								}
								if tc.ToolCallID == result.ToolCallID {
									toolCall = tc
									break
								}
							}
						}
					}
					switch result.Output.GetType() {
					case ai.ToolResultContentTypeText:
						out, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": out.Text}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})
					case ai.ToolResultContentTypeError:
						out, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": out.Error.Error()}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})
					}
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		}
	}
	return systemInstructions, content, warnings
}
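
// Summary of the role mapping performed by toGooglePrompt:
//
//	ai system    -> SystemInstruction (first system message only)
//	ai user      -> genai.RoleUser content (text and inline file parts)
//	ai assistant -> genai.RoleModel content (text and FunctionCall parts)
//	ai tool      -> genai.RoleUser content carrying FunctionResponse parts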

// Generate implements ai.LanguageModel.
func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	// config, contents, warnings, err := g.prepareParams(call)
	// if err != nil {
	// 	return nil, err
	// }
	panic("unimplemented")
}
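
// A hedged sketch of how Generate (above) might continue once prepareParams
// has produced config and contents, assuming the genai client's
// Models.GenerateContent call; mapping the response onto *ai.Response is
// left open:
//
//	resp, err := g.client.Models.GenerateContent(ctx, g.modelID, contents, config)
//	if err != nil {
//		return nil, err
//	}
//	// ...map resp.Candidates and resp.UsageMetadata onto *ai.Response,
//	// carrying the collected warnings...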

// Model implements ai.LanguageModel.
func (g *languageModel) Model() string {
	return g.modelID
}

// Provider implements ai.LanguageModel.
func (g *languageModel) Provider() string {
	return g.provider
}

// Stream implements ai.LanguageModel.
func (g *languageModel) Stream(context.Context, ai.Call) (ai.StreamResponse, error) {
	panic("unimplemented")
}

func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}

			required := []string{}
			var properties map[string]any
			if props, ok := ft.InputSchema["properties"]; ok {
				properties, _ = props.(map[string]any)
			}
			// The schema may carry "required" either as []string (built in Go)
			// or as []any (decoded from JSON); accept both.
			if req, ok := ft.InputSchema["required"]; ok {
				switch reqVal := req.(type) {
				case []string:
					required = reqVal
				case []any:
					for _, r := range reqVal {
						if s, ok := r.(string); ok {
							required = append(required, s)
						}
					}
				}
			}
			declaration := &genai.FunctionDeclaration{
				Name:        ft.Name,
				Description: ft.Description,
				Parameters: &genai.Schema{
					Type:       genai.TypeObject,
					Properties: convertSchemaProperties(properties),
					Required:   required,
				},
			}
			googleTools = append(googleTools, declaration)
			continue
		}
		// TODO: handle provider-defined tools.
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return
	}
	switch *toolChoice {
	case ai.ToolChoiceAuto:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAuto,
			},
		}
	case ai.ToolChoiceRequired:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
			},
		}
	case ai.ToolChoiceNone:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeNone,
			},
		}
	default:
		// Any other value names a specific tool; restrict calling to that tool.
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
				AllowedFunctionNames: []string{
					string(*toolChoice),
				},
			},
		}
	}
	return
}
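
// Sketch of the mapping above for a single function tool (the weather tool is
// illustrative only):
//
//	ai.FunctionTool{
//		Name:        "get_weather",
//		Description: "Look up current weather",
//		InputSchema: map[string]any{
//			"type": "object",
//			"properties": map[string]any{
//				"city": map[string]any{"type": "string", "description": "City name"},
//			},
//			"required": []string{"city"},
//		},
//	}
//
// becomes a genai.FunctionDeclaration named "get_weather" whose Parameters
// schema is TypeObject with a TypeString "city" property and Required = ["city"].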

func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

// convertToSchema converts one JSON Schema property definition into a
// genai.Schema, defaulting to a string schema when the shape is unrecognized.
func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGoogle(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}
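
// For example, a JSON Schema fragment such as
//
//	{"type": "array", "items": {"type": "object", "properties": {"id": {"type": "integer"}}}}
//
// converts to a genai.Schema of TypeArray whose Items schema is TypeObject
// with an "id" property of TypeInteger.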

func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

func mapJSONTypeToGoogle(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}