google.go

  1package google
  2
  3import (
  4	"context"
  5	"encoding/base64"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"maps"
 10	"net/http"
 11	"strings"
 12
 13	"github.com/charmbracelet/fantasy/ai"
 14	"github.com/charmbracelet/x/exp/slice"
 15	"google.golang.org/genai"
 16)
 17
// provider implements ai.Provider on top of the Google Gemini API.
type provider struct {
	options options
}

// options holds the configuration collected from the Option functions
// below and consumed by New / LanguageModel.
type options struct {
	apiKey  string            // Gemini API key passed to genai.ClientConfig
	name    string            // provider name; defaults to "google" in New
	headers map[string]string // extra HTTP headers (NOTE(review): never forwarded to the genai client — confirm intent)
	client  *http.Client      // optional custom HTTP client for the genai SDK
}

// Option mutates the provider configuration during construction; see New.
type Option = func(*options)
 30
 31func New(opts ...Option) ai.Provider {
 32	options := options{
 33		headers: map[string]string{},
 34	}
 35	for _, o := range opts {
 36		o(&options)
 37	}
 38
 39	if options.name == "" {
 40		options.name = "google"
 41	}
 42
 43	return &provider{
 44		options: options,
 45	}
 46}
 47
 48func WithAPIKey(apiKey string) Option {
 49	return func(o *options) {
 50		o.apiKey = apiKey
 51	}
 52}
 53
 54func WithName(name string) Option {
 55	return func(o *options) {
 56		o.name = name
 57	}
 58}
 59
 60func WithHeaders(headers map[string]string) Option {
 61	return func(o *options) {
 62		maps.Copy(o.headers, headers)
 63	}
 64}
 65
 66func WithHTTPClient(client *http.Client) Option {
 67	return func(o *options) {
 68		o.client = client
 69	}
 70}
 71
// languageModel is the Gemini-backed implementation of ai.LanguageModel.
type languageModel struct {
	provider        string // provider identifier, e.g. "google.generative-ai"
	modelID         string // Gemini model ID, e.g. "gemini-1.5-pro"
	client          *genai.Client
	providerOptions options // the provider options this model was created with
}
 78
 79// LanguageModel implements ai.Provider.
 80func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
 81	cc := &genai.ClientConfig{
 82		APIKey:     g.options.apiKey,
 83		Backend:    genai.BackendGeminiAPI,
 84		HTTPClient: g.options.client,
 85	}
 86	client, err := genai.NewClient(context.Background(), cc)
 87	if err != nil {
 88		return nil, err
 89	}
 90	return &languageModel{
 91		modelID:         modelID,
 92		provider:        fmt.Sprintf("%s.generative-ai", g.options.name),
 93		providerOptions: g.options,
 94		client:          client,
 95	}, nil
 96}
 97
 98func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
 99	config := &genai.GenerateContentConfig{}
100	providerOptions := &providerOptions{}
101	if v, ok := call.ProviderOptions["google"]; ok {
102		err := ai.ParseOptions(v, providerOptions)
103		if err != nil {
104			return nil, nil, nil, err
105		}
106	}
107
108	systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
109
110	if providerOptions.ThinkingConfig != nil &&
111		providerOptions.ThinkingConfig.IncludeThoughts != nil &&
112		*providerOptions.ThinkingConfig.IncludeThoughts &&
113		strings.HasPrefix(a.provider, "google.vertex.") {
114		warnings = append(warnings, ai.CallWarning{
115			Type: ai.CallWarningTypeOther,
116			Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
117				"and might not be supported or could behave unexpectedly with the current Google provider " +
118				fmt.Sprintf("(%s)", a.provider),
119		})
120	}
121
122	isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")
123
124	if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
125		if len(content) > 0 && content[0].Role == genai.RoleUser {
126			systemParts := []string{}
127			for _, sp := range systemInstructions.Parts {
128				systemParts = append(systemParts, sp.Text)
129			}
130			systemMsg := strings.Join(systemParts, "\n")
131			content[0].Parts = append([]*genai.Part{
132				{
133					Text: systemMsg + "\n\n",
134				},
135			}, content[0].Parts...)
136			systemInstructions = nil
137		}
138	}
139
140	config.SystemInstruction = systemInstructions
141
142	if call.MaxOutputTokens != nil {
143		config.MaxOutputTokens = int32(*call.MaxOutputTokens)
144	}
145
146	if call.Temperature != nil {
147		tmp := float32(*call.Temperature)
148		config.Temperature = &tmp
149	}
150	if call.TopK != nil {
151		tmp := float32(*call.TopK)
152		config.TopK = &tmp
153	}
154	if call.TopP != nil {
155		tmp := float32(*call.TopP)
156		config.TopP = &tmp
157	}
158	if call.FrequencyPenalty != nil {
159		tmp := float32(*call.FrequencyPenalty)
160		config.FrequencyPenalty = &tmp
161	}
162	if call.PresencePenalty != nil {
163		tmp := float32(*call.PresencePenalty)
164		config.PresencePenalty = &tmp
165	}
166
167	if providerOptions.ThinkingConfig != nil {
168		config.ThinkingConfig = &genai.ThinkingConfig{}
169		if providerOptions.ThinkingConfig.IncludeThoughts != nil {
170			config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
171		}
172		if providerOptions.ThinkingConfig.ThinkingBudget != nil {
173			tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget)
174			config.ThinkingConfig.ThinkingBudget = &tmp
175		}
176	}
177	for _, safetySetting := range providerOptions.SafetySettings {
178		config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
179			Category:  genai.HarmCategory(safetySetting.Category),
180			Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
181		})
182	}
183	if providerOptions.CachedContent != "" {
184		config.CachedContent = providerOptions.CachedContent
185	}
186
187	if len(call.Tools) > 0 {
188		tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
189		config.ToolConfig = toolChoice
190		config.Tools = append(config.Tools, &genai.Tool{
191			FunctionDeclarations: tools,
192		})
193		warnings = append(warnings, toolWarnings...)
194	}
195
196	return config, content, warnings, nil
197}
198
199func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) {
200	var systemInstructions *genai.Content
201	var content []*genai.Content
202	var warnings []ai.CallWarning
203
204	finishedSystemBlock := false
205	for _, msg := range prompt {
206		switch msg.Role {
207		case ai.MessageRoleSystem:
208			if finishedSystemBlock {
209				// skip multiple system messages that are separated by user/assistant messages
210				// TODO: see if we need to send error here?
211				continue
212			}
213			finishedSystemBlock = true
214
215			var systemMessages []string
216			for _, part := range msg.Content {
217				text, ok := ai.AsMessagePart[ai.TextPart](part)
218				if !ok || text.Text == "" {
219					continue
220				}
221				systemMessages = append(systemMessages, text.Text)
222			}
223			if len(systemMessages) > 0 {
224				systemInstructions = &genai.Content{
225					Parts: []*genai.Part{
226						{
227							Text: strings.Join(systemMessages, "\n"),
228						},
229					},
230				}
231			}
232		case ai.MessageRoleUser:
233			var parts []*genai.Part
234			for _, part := range msg.Content {
235				switch part.GetType() {
236				case ai.ContentTypeText:
237					text, ok := ai.AsMessagePart[ai.TextPart](part)
238					if !ok || text.Text == "" {
239						continue
240					}
241					parts = append(parts, &genai.Part{
242						Text: text.Text,
243					})
244				case ai.ContentTypeFile:
245					file, ok := ai.AsMessagePart[ai.FilePart](part)
246					if !ok {
247						continue
248					}
249					var encoded []byte
250					base64.StdEncoding.Encode(encoded, file.Data)
251					parts = append(parts, &genai.Part{
252						InlineData: &genai.Blob{
253							Data:     encoded,
254							MIMEType: file.MediaType,
255						},
256					})
257				}
258			}
259			if len(parts) > 0 {
260				content = append(content, &genai.Content{
261					Role:  genai.RoleUser,
262					Parts: parts,
263				})
264			}
265		case ai.MessageRoleAssistant:
266			var parts []*genai.Part
267			for _, part := range msg.Content {
268				switch part.GetType() {
269				case ai.ContentTypeText:
270					text, ok := ai.AsMessagePart[ai.TextPart](part)
271					if !ok || text.Text == "" {
272						continue
273					}
274					parts = append(parts, &genai.Part{
275						Text: text.Text,
276					})
277				case ai.ContentTypeToolCall:
278					toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
279					if !ok {
280						continue
281					}
282
283					var result map[string]any
284					err := json.Unmarshal([]byte(toolCall.Input), &result)
285					if err != nil {
286						continue
287					}
288					parts = append(parts, &genai.Part{
289						FunctionCall: &genai.FunctionCall{
290							ID:   toolCall.ToolCallID,
291							Name: toolCall.ToolName,
292							Args: result,
293						},
294					})
295				}
296			}
297			if len(parts) > 0 {
298				content = append(content, &genai.Content{
299					Role:  genai.RoleModel,
300					Parts: parts,
301				})
302			}
303		case ai.MessageRoleTool:
304			var parts []*genai.Part
305			for _, part := range msg.Content {
306				switch part.GetType() {
307				case ai.ContentTypeToolResult:
308					result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
309					if !ok {
310						continue
311					}
312					var toolCall ai.ToolCallPart
313					for _, m := range prompt {
314						if m.Role == ai.MessageRoleAssistant {
315							for _, content := range m.Content {
316								tc, ok := ai.AsMessagePart[ai.ToolCallPart](content)
317								if !ok {
318									continue
319								}
320								if tc.ToolCallID == result.ToolCallID {
321									toolCall = tc
322									break
323								}
324							}
325						}
326					}
327					switch result.Output.GetType() {
328					case ai.ToolResultContentTypeText:
329						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
330						if !ok {
331							continue
332						}
333						response := map[string]any{"result": content.Text}
334						parts = append(parts, &genai.Part{
335							FunctionResponse: &genai.FunctionResponse{
336								ID:       result.ToolCallID,
337								Response: response,
338								Name:     toolCall.ToolName,
339							},
340						})
341
342					case ai.ToolResultContentTypeError:
343						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
344						if !ok {
345							continue
346						}
347						response := map[string]any{"result": content.Error.Error()}
348						parts = append(parts, &genai.Part{
349							FunctionResponse: &genai.FunctionResponse{
350								ID:       result.ToolCallID,
351								Response: response,
352								Name:     toolCall.ToolName,
353							},
354						})
355
356					}
357				}
358			}
359			if len(parts) > 0 {
360				content = append(content, &genai.Content{
361					Role:  genai.RoleUser,
362					Parts: parts,
363				})
364			}
365		}
366	}
367	return systemInstructions, content, warnings
368}
369
370// Generate implements ai.LanguageModel.
371func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
372	config, contents, warnings, err := g.prepareParams(call)
373	if err != nil {
374		return nil, err
375	}
376
377	lastMessage, history, ok := slice.Pop(contents)
378	if !ok {
379		return nil, errors.New("no messages to send")
380	}
381
382	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
383	if err != nil {
384		return nil, err
385	}
386
387	response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
388	if err != nil {
389		return nil, err
390	}
391
392	return mapResponse(response, warnings)
393}
394
// Model implements ai.LanguageModel. It returns the Gemini model ID this
// language model was created with.
func (g *languageModel) Model() string {
	return g.modelID
}
399
// Provider implements ai.LanguageModel. It returns the provider identifier
// (e.g. "google.generative-ai") assigned when the model was created.
func (g *languageModel) Provider() string {
	return g.provider
}
404
// Stream implements ai.LanguageModel.
//
// Streaming is not implemented yet; this always returns an error. Use
// Generate for non-streaming responses.
func (g *languageModel) Stream(context.Context, ai.Call) (ai.StreamResponse, error) {
	return nil, errors.New("unimplemented")
}
409
410func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
411	for _, tool := range tools {
412		if tool.GetType() == ai.ToolTypeFunction {
413			ft, ok := tool.(ai.FunctionTool)
414			if !ok {
415				continue
416			}
417
418			required := []string{}
419			var properties map[string]any
420			if props, ok := ft.InputSchema["properties"]; ok {
421				properties, _ = props.(map[string]any)
422			}
423			if req, ok := ft.InputSchema["required"]; ok {
424				if reqArr, ok := req.([]string); ok {
425					required = reqArr
426				}
427			}
428			declaration := &genai.FunctionDeclaration{
429				Name:        ft.Name,
430				Description: ft.Description,
431				Parameters: &genai.Schema{
432					Type:       genai.TypeObject,
433					Properties: convertSchemaProperties(properties),
434					Required:   required,
435				},
436			}
437			googleTools = append(googleTools, declaration)
438			continue
439		}
440		// TODO: handle provider tool calls
441		warnings = append(warnings, ai.CallWarning{
442			Type:    ai.CallWarningTypeUnsupportedTool,
443			Tool:    tool,
444			Message: "tool is not supported",
445		})
446	}
447	if toolChoice == nil {
448		return
449	}
450	switch *toolChoice {
451	case ai.ToolChoiceAuto:
452		googleToolChoice = &genai.ToolConfig{
453			FunctionCallingConfig: &genai.FunctionCallingConfig{
454				Mode: genai.FunctionCallingConfigModeAuto,
455			},
456		}
457	case ai.ToolChoiceRequired:
458		googleToolChoice = &genai.ToolConfig{
459			FunctionCallingConfig: &genai.FunctionCallingConfig{
460				Mode: genai.FunctionCallingConfigModeAny,
461			},
462		}
463	case ai.ToolChoiceNone:
464		googleToolChoice = &genai.ToolConfig{
465			FunctionCallingConfig: &genai.FunctionCallingConfig{
466				Mode: genai.FunctionCallingConfigModeNone,
467			},
468		}
469	default:
470		googleToolChoice = &genai.ToolConfig{
471			FunctionCallingConfig: &genai.FunctionCallingConfig{
472				Mode: genai.FunctionCallingConfigModeAny,
473				AllowedFunctionNames: []string{
474					string(*toolChoice),
475				},
476			},
477		}
478	}
479	return
480}
481
482func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
483	properties := make(map[string]*genai.Schema)
484
485	for name, param := range parameters {
486		properties[name] = convertToSchema(param)
487	}
488
489	return properties
490}
491
492func convertToSchema(param any) *genai.Schema {
493	schema := &genai.Schema{Type: genai.TypeString}
494
495	paramMap, ok := param.(map[string]any)
496	if !ok {
497		return schema
498	}
499
500	if desc, ok := paramMap["description"].(string); ok {
501		schema.Description = desc
502	}
503
504	typeVal, hasType := paramMap["type"]
505	if !hasType {
506		return schema
507	}
508
509	typeStr, ok := typeVal.(string)
510	if !ok {
511		return schema
512	}
513
514	schema.Type = mapJSONTypeToGoogle(typeStr)
515
516	switch typeStr {
517	case "array":
518		schema.Items = processArrayItems(paramMap)
519	case "object":
520		if props, ok := paramMap["properties"].(map[string]any); ok {
521			schema.Properties = convertSchemaProperties(props)
522		}
523	}
524
525	return schema
526}
527
528func processArrayItems(paramMap map[string]any) *genai.Schema {
529	items, ok := paramMap["items"].(map[string]any)
530	if !ok {
531		return nil
532	}
533
534	return convertToSchema(items)
535}
536
537func mapJSONTypeToGoogle(jsonType string) genai.Type {
538	switch jsonType {
539	case "string":
540		return genai.TypeString
541	case "number":
542		return genai.TypeNumber
543	case "integer":
544		return genai.TypeInteger
545	case "boolean":
546		return genai.TypeBoolean
547	case "array":
548		return genai.TypeArray
549	case "object":
550		return genai.TypeObject
551	default:
552		return genai.TypeString // Default to string for unknown types
553	}
554}
555
556func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
557	if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
558		return nil, errors.New("no response from model")
559	}
560
561	var (
562		content      []ai.Content
563		finishReason ai.FinishReason
564		hasToolCalls bool
565		candidate    = response.Candidates[0]
566	)
567
568	for _, part := range candidate.Content.Parts {
569		switch {
570		case part.Text != "":
571			content = append(content, ai.TextContent{Text: part.Text})
572		case part.FunctionCall != nil:
573			input, err := json.Marshal(part.FunctionCall.Args)
574			if err != nil {
575				return nil, err
576			}
577			content = append(content, ai.ToolCallContent{
578				ToolCallID:       part.FunctionCall.ID,
579				ToolName:         part.FunctionCall.Name,
580				Input:            string(input),
581				ProviderExecuted: false,
582			})
583			hasToolCalls = true
584		default:
585			return nil, fmt.Errorf("not implemented part type")
586		}
587	}
588
589	if hasToolCalls {
590		finishReason = ai.FinishReasonToolCalls
591	} else {
592		finishReason = mapFinishReason(candidate.FinishReason)
593	}
594
595	return &ai.Response{
596		Content:      content,
597		Usage:        mapUsage(response.UsageMetadata),
598		FinishReason: finishReason,
599		Warnings:     warnings,
600	}, nil
601}
602
603func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
604	switch reason {
605	case genai.FinishReasonStop:
606		return ai.FinishReasonStop
607	case genai.FinishReasonMaxTokens:
608		return ai.FinishReasonLength
609	case genai.FinishReasonSafety,
610		genai.FinishReasonBlocklist,
611		genai.FinishReasonProhibitedContent,
612		genai.FinishReasonSPII,
613		genai.FinishReasonImageSafety:
614		return ai.FinishReasonContentFilter
615	case genai.FinishReasonRecitation,
616		genai.FinishReasonLanguage,
617		genai.FinishReasonMalformedFunctionCall:
618		return ai.FinishReasonError
619	case genai.FinishReasonOther:
620		return ai.FinishReasonOther
621	default:
622		return ai.FinishReasonUnknown
623	}
624}
625
626func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
627	return ai.Usage{
628		InputTokens:         int64(usage.ToolUsePromptTokenCount),
629		OutputTokens:        int64(usage.CandidatesTokenCount),
630		TotalTokens:         int64(usage.TotalTokenCount),
631		ReasoningTokens:     int64(usage.ThoughtsTokenCount),
632		CacheCreationTokens: int64(usage.CachedContentTokenCount),
633		CacheReadTokens:     0,
634	}
635}