google.go

package google

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"maps"
	"net/http"
	"strings"

	"github.com/charmbracelet/fantasy/ai"
	"google.golang.org/genai"
)

type provider struct {
	options options
}

type options struct {
	apiKey  string
	name    string
	headers map[string]string
	client  *http.Client
}

// Option configures the Google provider.
type Option = func(*options)

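// New constructs a Google Gemini provider from the given options.
//
// A minimal usage sketch (the model ID below is illustrative only):
//
//	p := google.New(google.WithAPIKey(os.Getenv("GEMINI_API_KEY")))
//	model, err := p.LanguageModel("gemini-2.0-flash")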
func New(opts ...Option) ai.Provider {
	options := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	if options.name == "" {
		options.name = "google"
	}

	return &provider{
		options: options,
	}
}

// WithAPIKey sets the Gemini API key used to authenticate requests.
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

// WithName overrides the provider name used in the provider identifier
// (defaults to "google").
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

// WithHeaders adds extra headers to the provider options.
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

// WithHTTPClient sets the HTTP client passed to the underlying genai client.
func WithHTTPClient(client *http.Client) Option {
	return func(o *options) {
		o.client = client
	}
}

// languageModel implements ai.LanguageModel on top of a genai client.
type languageModel struct {
	provider        string
	modelID         string
	client          *genai.Client
	providerOptions options
}

// LanguageModel implements ai.Provider.
func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	cc := &genai.ClientConfig{
		APIKey:     g.options.apiKey,
		Backend:    genai.BackendGeminiAPI,
		HTTPClient: g.options.client,
	}
	client, err := genai.NewClient(context.Background(), cc)
	if err != nil {
		return nil, err
	}
	return &languageModel{
		modelID:         modelID,
		provider:        fmt.Sprintf("%s.generative-ai", g.options.name),
		providerOptions: g.options,
		client:          client,
	}, nil
}

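// prepareParams translates an ai.Call into a genai.GenerateContentConfig and
// the ordered conversation contents, collecting any call warnings on the way.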
func (g *languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
	config := &genai.GenerateContentConfig{}
	providerOptions := &providerOptions{}
	if v, ok := call.ProviderOptions["google"]; ok {
		err := ai.ParseOptions(v, providerOptions)
		if err != nil {
			return nil, nil, nil, err
		}
	}

	systemInstructions, content, warnings := toGooglePrompt(call.Prompt)

	if providerOptions.ThinkingConfig != nil &&
		providerOptions.ThinkingConfig.IncludeThoughts != nil &&
		*providerOptions.ThinkingConfig.IncludeThoughts &&
		!strings.HasPrefix(g.provider, "google.vertex.") {
		warnings = append(warnings, ai.CallWarning{
			Type: ai.CallWarningTypeOther,
			Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
				"and might not be supported or could behave unexpectedly with the current Google provider " +
				fmt.Sprintf("(%s)", g.provider),
		})
	}

	// Gemma models do not support system instructions, so fold them into the
	// first user message instead.
	isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")

	if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
		if len(content) > 0 && content[0].Role == genai.RoleUser {
			systemParts := []string{}
			for _, sp := range systemInstructions.Parts {
				systemParts = append(systemParts, sp.Text)
			}
			systemMsg := strings.Join(systemParts, "\n")
			content[0].Parts = append([]*genai.Part{
				{
					Text: systemMsg + "\n\n",
				},
			}, content[0].Parts...)
			systemInstructions = nil
		}
	}

	config.SystemInstruction = systemInstructions

	if call.MaxOutputTokens != nil {
		config.MaxOutputTokens = int32(*call.MaxOutputTokens)
	}

	if call.Temperature != nil {
		tmp := float32(*call.Temperature)
		config.Temperature = &tmp
	}
	if call.TopK != nil {
		tmp := float32(*call.TopK)
		config.TopK = &tmp
	}
	if call.TopP != nil {
		tmp := float32(*call.TopP)
		config.TopP = &tmp
	}
	if call.FrequencyPenalty != nil {
		tmp := float32(*call.FrequencyPenalty)
		config.FrequencyPenalty = &tmp
	}
	if call.PresencePenalty != nil {
		tmp := float32(*call.PresencePenalty)
		config.PresencePenalty = &tmp
	}

	if providerOptions.ThinkingConfig != nil {
		config.ThinkingConfig = &genai.ThinkingConfig{}
		if providerOptions.ThinkingConfig.IncludeThoughts != nil {
			config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
		}
		if providerOptions.ThinkingConfig.ThinkingBudget != nil {
			tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget)
			config.ThinkingConfig.ThinkingBudget = &tmp
		}
	}
	for _, safetySetting := range providerOptions.SafetySettings {
		config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
			Category:  genai.HarmCategory(safetySetting.Category),
			Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
		})
	}
	if providerOptions.CachedContent != "" {
		config.CachedContent = providerOptions.CachedContent
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
		config.ToolConfig = toolChoice
		config.Tools = append(config.Tools, &genai.Tool{
			FunctionDeclarations: tools,
		})
		warnings = append(warnings, toolWarnings...)
	}

	return config, content, warnings, nil
}

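// toGooglePrompt converts an ai.Prompt into an optional system-instruction
// content block plus the remaining conversation, mapping ai roles onto genai
// roles: assistant messages become model content, and tool results become user
// content carrying function responses.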
func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) {
	var systemInstructions *genai.Content
	var content []*genai.Content
	var warnings []ai.CallWarning

	finishedSystemBlock := false
	for _, msg := range prompt {
		if msg.Role != ai.MessageRoleSystem {
			finishedSystemBlock = true
		}
		switch msg.Role {
		case ai.MessageRoleSystem:
			if finishedSystemBlock {
				// skip system messages that appear after user/assistant messages
				// TODO: see if we need to send error here?
				continue
			}

			var systemMessages []string
			for _, part := range msg.Content {
				text, ok := ai.AsMessagePart[ai.TextPart](part)
				if !ok || text.Text == "" {
					continue
				}
				systemMessages = append(systemMessages, text.Text)
			}
			if len(systemMessages) > 0 {
				if systemInstructions == nil {
					systemInstructions = &genai.Content{}
				}
				systemInstructions.Parts = append(systemInstructions.Parts, &genai.Part{
					Text: strings.Join(systemMessages, "\n"),
				})
			}
		case ai.MessageRoleUser:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeFile:
					file, ok := ai.AsMessagePart[ai.FilePart](part)
					if !ok {
						continue
					}
					// genai.Blob carries the raw bytes; the SDK/JSON layer
					// base64-encodes them on the wire.
					parts = append(parts, &genai.Part{
						InlineData: &genai.Blob{
							Data:     file.Data,
							MIMEType: file.MediaType,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		case ai.MessageRoleAssistant:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeToolCall:
					toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
					if !ok {
						continue
					}

					var result map[string]any
					err := json.Unmarshal([]byte(toolCall.Input), &result)
					if err != nil {
						continue
					}
					parts = append(parts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							ID:   toolCall.ToolCallID,
							Name: toolCall.ToolName,
							Args: result,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleModel,
					Parts: parts,
				})
			}
		case ai.MessageRoleTool:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeToolResult:
					result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
					if !ok {
						continue
					}
					// Find the assistant tool call that produced this result so
					// the function response can carry the tool name.
					var toolCall ai.ToolCallPart
				search:
					for _, m := range prompt {
						if m.Role != ai.MessageRoleAssistant {
							continue
						}
						for _, c := range m.Content {
							tc, ok := ai.AsMessagePart[ai.ToolCallPart](c)
							if !ok {
								continue
							}
							if tc.ToolCallID == result.ToolCallID {
								toolCall = tc
								break search
							}
						}
					}
					switch result.Output.GetType() {
					case ai.ToolResultContentTypeText:
						output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": output.Text}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})
					case ai.ToolResultContentTypeError:
						output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": output.Error.Error()}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})
					}
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		}
	}
	return systemInstructions, content, warnings
}

// Generate implements ai.LanguageModel.
func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	// config, content, warnings, err := g.prepareParams(call)
	// if err != nil {
	// 	return nil, err
	// }
	return nil, errors.New("unimplemented")
}

// Model implements ai.LanguageModel.
func (g *languageModel) Model() string {
	return g.modelID
}

// Provider implements ai.LanguageModel.
func (g *languageModel) Provider() string {
	return g.provider
}

// Stream implements ai.LanguageModel.
func (g *languageModel) Stream(context.Context, ai.Call) (ai.StreamResponse, error) {
	return nil, errors.New("unimplemented")
}

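// toGoogleTools converts the call's function tools and tool-choice setting
// into genai function declarations and a matching genai.ToolConfig, emitting a
// warning for tool types that are not supported.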
func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}

			required := []string{}
			var properties map[string]any
			if props, ok := ft.InputSchema["properties"]; ok {
				properties, _ = props.(map[string]any)
			}
			if req, ok := ft.InputSchema["required"]; ok {
				// JSON-decoded schemas carry []any here; hand-built ones may use []string.
				switch reqVal := req.(type) {
				case []string:
					required = reqVal
				case []any:
					for _, r := range reqVal {
						if s, ok := r.(string); ok {
							required = append(required, s)
						}
					}
				}
			}
			declaration := &genai.FunctionDeclaration{
				Name:        ft.Name,
				Description: ft.Description,
				Parameters: &genai.Schema{
					Type:       genai.TypeObject,
					Properties: convertSchemaProperties(properties),
					Required:   required,
				},
			}
			googleTools = append(googleTools, declaration)
			continue
		}
		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return
	}
	switch *toolChoice {
	case ai.ToolChoiceAuto:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAuto,
			},
		}
	case ai.ToolChoiceRequired:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
			},
		}
	case ai.ToolChoiceNone:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeNone,
			},
		}
	default:
		// A specific tool name: force a call to exactly that function.
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
				AllowedFunctionNames: []string{
					string(*toolChoice),
				},
			},
		}
	}
	return
}

func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

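// convertToSchema maps a single JSON Schema property (as a generic map) onto a
// genai.Schema, defaulting to a string schema when the shape is not recognized.
//
// For example, a property such as
//
//	map[string]any{"type": "array", "items": map[string]any{"type": "string"}}
//
// becomes a schema of TypeArray whose Items schema has TypeString.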
func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGoogle(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

func mapJSONTypeToGoogle(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}