google.go

  1package google
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/base64"
  7	"encoding/json"
  8	"errors"
  9	"fmt"
 10	"maps"
 11	"net/http"
 12	"strings"
 13
 14	"github.com/charmbracelet/fantasy/ai"
 15	"github.com/charmbracelet/x/exp/slice"
 16	"github.com/google/uuid"
 17	"google.golang.org/genai"
 18)
 19
// provider creates Google Gemini-backed language models. It implements
// ai.Provider.
type provider struct {
	options options
}

// options holds the configurable settings collected by the Option funcs.
type options struct {
	apiKey  string            // Gemini API key used to authenticate requests
	name    string            // provider name; New defaults it to "google"
	headers map[string]string // extra HTTP headers registered via WithHeaders
	client  *http.Client      // optional custom HTTP client for the genai SDK
}

// Option mutates the provider options during New.
type Option = func(*options)
 32
 33func New(opts ...Option) ai.Provider {
 34	options := options{
 35		headers: map[string]string{},
 36	}
 37	for _, o := range opts {
 38		o(&options)
 39	}
 40
 41	if options.name == "" {
 42		options.name = "google"
 43	}
 44
 45	return &provider{
 46		options: options,
 47	}
 48}
 49
// WithAPIKey sets the Gemini API key used to authenticate requests.
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}
 55
// WithName overrides the provider name used in the model's provider
// identifier (defaults to "google"; see New).
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}
 61
// WithHeaders registers additional HTTP headers for the provider. Repeated
// calls merge into the same map; later calls overwrite duplicate keys.
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}
 67
// WithHTTPClient sets a custom HTTP client for the underlying genai SDK
// (useful for proxies, custom TLS, or request instrumentation).
func WithHTTPClient(client *http.Client) Option {
	return func(o *options) {
		o.client = client
	}
}
 73
// languageModel is an ai.LanguageModel backed by a single Gemini model
// reachable through a shared genai client.
type languageModel struct {
	provider        string        // provider identifier, e.g. "google.generative-ai"
	modelID         string        // Gemini model name, e.g. "gemini-2.0-flash"
	client          *genai.Client // SDK client created in provider.LanguageModel
	providerOptions options       // copy of the provider-level options
}
 80
 81// LanguageModel implements ai.Provider.
 82func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
 83	cc := &genai.ClientConfig{
 84		APIKey:     g.options.apiKey,
 85		Backend:    genai.BackendGeminiAPI,
 86		HTTPClient: g.options.client,
 87	}
 88	client, err := genai.NewClient(context.Background(), cc)
 89	if err != nil {
 90		return nil, err
 91	}
 92	return &languageModel{
 93		modelID:         modelID,
 94		provider:        fmt.Sprintf("%s.generative-ai", g.options.name),
 95		providerOptions: g.options,
 96		client:          client,
 97	}, nil
 98}
 99
100func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
101	config := &genai.GenerateContentConfig{}
102	providerOptions := &providerOptions{}
103	if v, ok := call.ProviderOptions["google"]; ok {
104		err := ai.ParseOptions(v, providerOptions)
105		if err != nil {
106			return nil, nil, nil, err
107		}
108	}
109
110	systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
111
112	if providerOptions.ThinkingConfig != nil &&
113		providerOptions.ThinkingConfig.IncludeThoughts != nil &&
114		*providerOptions.ThinkingConfig.IncludeThoughts &&
115		strings.HasPrefix(a.provider, "google.vertex.") {
116		warnings = append(warnings, ai.CallWarning{
117			Type: ai.CallWarningTypeOther,
118			Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
119				"and might not be supported or could behave unexpectedly with the current Google provider " +
120				fmt.Sprintf("(%s)", a.provider),
121		})
122	}
123
124	isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")
125
126	if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
127		if len(content) > 0 && content[0].Role == genai.RoleUser {
128			systemParts := []string{}
129			for _, sp := range systemInstructions.Parts {
130				systemParts = append(systemParts, sp.Text)
131			}
132			systemMsg := strings.Join(systemParts, "\n")
133			content[0].Parts = append([]*genai.Part{
134				{
135					Text: systemMsg + "\n\n",
136				},
137			}, content[0].Parts...)
138			systemInstructions = nil
139		}
140	}
141
142	config.SystemInstruction = systemInstructions
143
144	if call.MaxOutputTokens != nil {
145		config.MaxOutputTokens = int32(*call.MaxOutputTokens)
146	}
147
148	if call.Temperature != nil {
149		tmp := float32(*call.Temperature)
150		config.Temperature = &tmp
151	}
152	if call.TopK != nil {
153		tmp := float32(*call.TopK)
154		config.TopK = &tmp
155	}
156	if call.TopP != nil {
157		tmp := float32(*call.TopP)
158		config.TopP = &tmp
159	}
160	if call.FrequencyPenalty != nil {
161		tmp := float32(*call.FrequencyPenalty)
162		config.FrequencyPenalty = &tmp
163	}
164	if call.PresencePenalty != nil {
165		tmp := float32(*call.PresencePenalty)
166		config.PresencePenalty = &tmp
167	}
168
169	if providerOptions.ThinkingConfig != nil {
170		config.ThinkingConfig = &genai.ThinkingConfig{}
171		if providerOptions.ThinkingConfig.IncludeThoughts != nil {
172			config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
173		}
174		if providerOptions.ThinkingConfig.ThinkingBudget != nil {
175			tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget)
176			config.ThinkingConfig.ThinkingBudget = &tmp
177		}
178	}
179	for _, safetySetting := range providerOptions.SafetySettings {
180		config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
181			Category:  genai.HarmCategory(safetySetting.Category),
182			Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
183		})
184	}
185	if providerOptions.CachedContent != "" {
186		config.CachedContent = providerOptions.CachedContent
187	}
188
189	if len(call.Tools) > 0 {
190		tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
191		config.ToolConfig = toolChoice
192		config.Tools = append(config.Tools, &genai.Tool{
193			FunctionDeclarations: tools,
194		})
195		warnings = append(warnings, toolWarnings...)
196	}
197
198	return config, content, warnings, nil
199}
200
201func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) {
202	var systemInstructions *genai.Content
203	var content []*genai.Content
204	var warnings []ai.CallWarning
205
206	finishedSystemBlock := false
207	for _, msg := range prompt {
208		switch msg.Role {
209		case ai.MessageRoleSystem:
210			if finishedSystemBlock {
211				// skip multiple system messages that are separated by user/assistant messages
212				// TODO: see if we need to send error here?
213				continue
214			}
215			finishedSystemBlock = true
216
217			var systemMessages []string
218			for _, part := range msg.Content {
219				text, ok := ai.AsMessagePart[ai.TextPart](part)
220				if !ok || text.Text == "" {
221					continue
222				}
223				systemMessages = append(systemMessages, text.Text)
224			}
225			if len(systemMessages) > 0 {
226				systemInstructions = &genai.Content{
227					Parts: []*genai.Part{
228						{
229							Text: strings.Join(systemMessages, "\n"),
230						},
231					},
232				}
233			}
234		case ai.MessageRoleUser:
235			var parts []*genai.Part
236			for _, part := range msg.Content {
237				switch part.GetType() {
238				case ai.ContentTypeText:
239					text, ok := ai.AsMessagePart[ai.TextPart](part)
240					if !ok || text.Text == "" {
241						continue
242					}
243					parts = append(parts, &genai.Part{
244						Text: text.Text,
245					})
246				case ai.ContentTypeFile:
247					file, ok := ai.AsMessagePart[ai.FilePart](part)
248					if !ok {
249						continue
250					}
251					var encoded []byte
252					base64.StdEncoding.Encode(encoded, file.Data)
253					parts = append(parts, &genai.Part{
254						InlineData: &genai.Blob{
255							Data:     encoded,
256							MIMEType: file.MediaType,
257						},
258					})
259				}
260			}
261			if len(parts) > 0 {
262				content = append(content, &genai.Content{
263					Role:  genai.RoleUser,
264					Parts: parts,
265				})
266			}
267		case ai.MessageRoleAssistant:
268			var parts []*genai.Part
269			for _, part := range msg.Content {
270				switch part.GetType() {
271				case ai.ContentTypeText:
272					text, ok := ai.AsMessagePart[ai.TextPart](part)
273					if !ok || text.Text == "" {
274						continue
275					}
276					parts = append(parts, &genai.Part{
277						Text: text.Text,
278					})
279				case ai.ContentTypeToolCall:
280					toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
281					if !ok {
282						continue
283					}
284
285					var result map[string]any
286					err := json.Unmarshal([]byte(toolCall.Input), &result)
287					if err != nil {
288						continue
289					}
290					parts = append(parts, &genai.Part{
291						FunctionCall: &genai.FunctionCall{
292							ID:   toolCall.ToolCallID,
293							Name: toolCall.ToolName,
294							Args: result,
295						},
296					})
297				}
298			}
299			if len(parts) > 0 {
300				content = append(content, &genai.Content{
301					Role:  genai.RoleModel,
302					Parts: parts,
303				})
304			}
305		case ai.MessageRoleTool:
306			var parts []*genai.Part
307			for _, part := range msg.Content {
308				switch part.GetType() {
309				case ai.ContentTypeToolResult:
310					result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
311					if !ok {
312						continue
313					}
314					var toolCall ai.ToolCallPart
315					for _, m := range prompt {
316						if m.Role == ai.MessageRoleAssistant {
317							for _, content := range m.Content {
318								tc, ok := ai.AsMessagePart[ai.ToolCallPart](content)
319								if !ok {
320									continue
321								}
322								if tc.ToolCallID == result.ToolCallID {
323									toolCall = tc
324									break
325								}
326							}
327						}
328					}
329					switch result.Output.GetType() {
330					case ai.ToolResultContentTypeText:
331						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
332						if !ok {
333							continue
334						}
335						response := map[string]any{"result": content.Text}
336						parts = append(parts, &genai.Part{
337							FunctionResponse: &genai.FunctionResponse{
338								ID:       result.ToolCallID,
339								Response: response,
340								Name:     toolCall.ToolName,
341							},
342						})
343
344					case ai.ToolResultContentTypeError:
345						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
346						if !ok {
347							continue
348						}
349						response := map[string]any{"result": content.Error.Error()}
350						parts = append(parts, &genai.Part{
351							FunctionResponse: &genai.FunctionResponse{
352								ID:       result.ToolCallID,
353								Response: response,
354								Name:     toolCall.ToolName,
355							},
356						})
357
358					}
359				}
360			}
361			if len(parts) > 0 {
362				content = append(content, &genai.Content{
363					Role:  genai.RoleUser,
364					Parts: parts,
365				})
366			}
367		}
368	}
369	return systemInstructions, content, warnings
370}
371
// Generate implements ai.LanguageModel. It sends the prepared prompt as a
// single (non-streaming) request and maps the genai response into an
// ai.Response. Returns an error when the prompt yields no sendable content.
func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}

	// The genai chat API takes prior turns as history and the final message
	// as the outgoing payload, so split the last content off the slice.
	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}

	response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
	if err != nil {
		return nil, err
	}

	return mapResponse(response, warnings)
}
396
// Model implements ai.LanguageModel. It returns the Gemini model identifier
// this instance was created with.
func (g *languageModel) Model() string {
	return g.modelID
}
401
// Provider implements ai.LanguageModel. It returns the provider identifier,
// e.g. "google.generative-ai".
func (g *languageModel) Provider() string {
	return g.provider
}
406
// Stream implements ai.LanguageModel. It returns an iterator that yields
// stream parts (warnings, text deltas, tool calls, and a final finish part
// with usage) as the genai SDK streams candidates back. Any yield returning
// false stops the stream early.
func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}

	// The genai chat API takes prior turns as history and the final message
	// as the outgoing payload, so split the last content off the slice.
	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}

	return func(yield func(ai.StreamPart) bool) {
		// Surface preparation warnings before any content parts.
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		var currentContent string          // accumulated text; written but not read further in this function
		var toolCalls []ai.ToolCallContent // tool calls seen so far; drives the finish reason
		var isActiveText bool              // true while the single text block (ID "0") is open
		var usage ai.Usage                 // last usage metadata chunk wins

		// Stream the response
		for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
			if err != nil {
				yield(ai.StreamPart{
					Type:  ai.StreamPartTypeError,
					Error: err,
				})
				return
			}

			// Only the first candidate is consumed.
			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
				for _, part := range resp.Candidates[0].Content.Parts {
					switch {
					case part.Text != "":
						delta := part.Text
						if delta != "" {
							// Open the text block lazily on the first delta;
							// all text shares the fixed ID "0".
							if !isActiveText {
								isActiveText = true
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeTextStart,
									ID:   "0",
								}) {
									return
								}
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeTextDelta,
								ID:    "0",
								Delta: delta,
							}) {
								return
							}
							currentContent += delta
						}
					case part.FunctionCall != nil:
						// A function call ends any open text block first so
						// parts arrive in a well-formed order.
						if isActiveText {
							isActiveText = false
							if !yield(ai.StreamPart{
								Type: ai.StreamPartTypeTextEnd,
								ID:   "0",
							}) {
								return
							}
						}

						// Prefer the SDK-provided ID, then the function name,
						// then a random UUID as a last resort.
						toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())

						args, err := json.Marshal(part.FunctionCall.Args)
						if err != nil {
							yield(ai.StreamPart{
								Type:  ai.StreamPartTypeError,
								Error: err,
							})
							return
						}

						// Gemini delivers the full call at once, so emit the
						// input-start/delta/end trio followed by the
						// complete tool call.
						if !yield(ai.StreamPart{
							Type:         ai.StreamPartTypeToolInputStart,
							ID:           toolCallID,
							ToolCallName: part.FunctionCall.Name,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:  ai.StreamPartTypeToolInputDelta,
							ID:    toolCallID,
							Delta: string(args),
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeToolInputEnd,
							ID:   toolCallID,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:             ai.StreamPartTypeToolCall,
							ID:               toolCallID,
							ToolCallName:     part.FunctionCall.Name,
							ToolCallInput:    string(args),
							ProviderExecuted: false,
						}) {
							return
						}

						toolCalls = append(toolCalls, ai.ToolCallContent{
							ToolCallID:       toolCallID,
							ToolName:         part.FunctionCall.Name,
							Input:            string(args),
							ProviderExecuted: false,
						})
					}
				}
			}

			if resp.UsageMetadata != nil {
				usage = mapUsage(resp.UsageMetadata)
			}
		}

		// Close a text block that was still open when the stream ended.
		if isActiveText {
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID:   "0",
			}) {
				return
			}
		}

		// NOTE(review): unlike Generate, the candidate's own FinishReason is
		// ignored here — length/content-filter stops are reported as Stop.
		// Confirm whether that is intended.
		finishReason := ai.FinishReasonStop
		if len(toolCalls) > 0 {
			finishReason = ai.FinishReasonToolCalls
		}

		yield(ai.StreamPart{
			Type:         ai.StreamPartTypeFinish,
			Usage:        usage,
			FinishReason: finishReason,
		})
	}, nil
}
564
565func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
566	for _, tool := range tools {
567		if tool.GetType() == ai.ToolTypeFunction {
568			ft, ok := tool.(ai.FunctionTool)
569			if !ok {
570				continue
571			}
572
573			required := []string{}
574			var properties map[string]any
575			if props, ok := ft.InputSchema["properties"]; ok {
576				properties, _ = props.(map[string]any)
577			}
578			if req, ok := ft.InputSchema["required"]; ok {
579				if reqArr, ok := req.([]string); ok {
580					required = reqArr
581				}
582			}
583			declaration := &genai.FunctionDeclaration{
584				Name:        ft.Name,
585				Description: ft.Description,
586				Parameters: &genai.Schema{
587					Type:       genai.TypeObject,
588					Properties: convertSchemaProperties(properties),
589					Required:   required,
590				},
591			}
592			googleTools = append(googleTools, declaration)
593			continue
594		}
595		// TODO: handle provider tool calls
596		warnings = append(warnings, ai.CallWarning{
597			Type:    ai.CallWarningTypeUnsupportedTool,
598			Tool:    tool,
599			Message: "tool is not supported",
600		})
601	}
602	if toolChoice == nil {
603		return
604	}
605	switch *toolChoice {
606	case ai.ToolChoiceAuto:
607		googleToolChoice = &genai.ToolConfig{
608			FunctionCallingConfig: &genai.FunctionCallingConfig{
609				Mode: genai.FunctionCallingConfigModeAuto,
610			},
611		}
612	case ai.ToolChoiceRequired:
613		googleToolChoice = &genai.ToolConfig{
614			FunctionCallingConfig: &genai.FunctionCallingConfig{
615				Mode: genai.FunctionCallingConfigModeAny,
616			},
617		}
618	case ai.ToolChoiceNone:
619		googleToolChoice = &genai.ToolConfig{
620			FunctionCallingConfig: &genai.FunctionCallingConfig{
621				Mode: genai.FunctionCallingConfigModeNone,
622			},
623		}
624	default:
625		googleToolChoice = &genai.ToolConfig{
626			FunctionCallingConfig: &genai.FunctionCallingConfig{
627				Mode: genai.FunctionCallingConfigModeAny,
628				AllowedFunctionNames: []string{
629					string(*toolChoice),
630				},
631			},
632		}
633	}
634	return
635}
636
637func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
638	properties := make(map[string]*genai.Schema)
639
640	for name, param := range parameters {
641		properties[name] = convertToSchema(param)
642	}
643
644	return properties
645}
646
647func convertToSchema(param any) *genai.Schema {
648	schema := &genai.Schema{Type: genai.TypeString}
649
650	paramMap, ok := param.(map[string]any)
651	if !ok {
652		return schema
653	}
654
655	if desc, ok := paramMap["description"].(string); ok {
656		schema.Description = desc
657	}
658
659	typeVal, hasType := paramMap["type"]
660	if !hasType {
661		return schema
662	}
663
664	typeStr, ok := typeVal.(string)
665	if !ok {
666		return schema
667	}
668
669	schema.Type = mapJSONTypeToGoogle(typeStr)
670
671	switch typeStr {
672	case "array":
673		schema.Items = processArrayItems(paramMap)
674	case "object":
675		if props, ok := paramMap["properties"].(map[string]any); ok {
676			schema.Properties = convertSchemaProperties(props)
677		}
678	}
679
680	return schema
681}
682
683func processArrayItems(paramMap map[string]any) *genai.Schema {
684	items, ok := paramMap["items"].(map[string]any)
685	if !ok {
686		return nil
687	}
688
689	return convertToSchema(items)
690}
691
692func mapJSONTypeToGoogle(jsonType string) genai.Type {
693	switch jsonType {
694	case "string":
695		return genai.TypeString
696	case "number":
697		return genai.TypeNumber
698	case "integer":
699		return genai.TypeInteger
700	case "boolean":
701		return genai.TypeBoolean
702	case "array":
703		return genai.TypeArray
704	case "object":
705		return genai.TypeObject
706	default:
707		return genai.TypeString // Default to string for unknown types
708	}
709}
710
711func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
712	if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
713		return nil, errors.New("no response from model")
714	}
715
716	var (
717		content      []ai.Content
718		finishReason ai.FinishReason
719		hasToolCalls bool
720		candidate    = response.Candidates[0]
721	)
722
723	for _, part := range candidate.Content.Parts {
724		switch {
725		case part.Text != "":
726			content = append(content, ai.TextContent{Text: part.Text})
727		case part.FunctionCall != nil:
728			input, err := json.Marshal(part.FunctionCall.Args)
729			if err != nil {
730				return nil, err
731			}
732			content = append(content, ai.ToolCallContent{
733				ToolCallID:       part.FunctionCall.ID,
734				ToolName:         part.FunctionCall.Name,
735				Input:            string(input),
736				ProviderExecuted: false,
737			})
738			hasToolCalls = true
739		default:
740			return nil, fmt.Errorf("not implemented part type")
741		}
742	}
743
744	if hasToolCalls {
745		finishReason = ai.FinishReasonToolCalls
746	} else {
747		finishReason = mapFinishReason(candidate.FinishReason)
748	}
749
750	return &ai.Response{
751		Content:      content,
752		Usage:        mapUsage(response.UsageMetadata),
753		FinishReason: finishReason,
754		Warnings:     warnings,
755	}, nil
756}
757
// mapFinishReason maps a genai finish reason onto the generic ai finish
// reason. Safety-style stops collapse to ContentFilter, model-failure stops
// to Error, and anything unrecognized to Unknown.
func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return ai.FinishReasonStop
	case genai.FinishReasonMaxTokens:
		return ai.FinishReasonLength
	case genai.FinishReasonSafety,
		genai.FinishReasonBlocklist,
		genai.FinishReasonProhibitedContent,
		genai.FinishReasonSPII,
		genai.FinishReasonImageSafety:
		return ai.FinishReasonContentFilter
	case genai.FinishReasonRecitation,
		genai.FinishReasonLanguage,
		genai.FinishReasonMalformedFunctionCall:
		return ai.FinishReasonError
	case genai.FinishReasonOther:
		return ai.FinishReasonOther
	default:
		return ai.FinishReasonUnknown
	}
}
780
781func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
782	return ai.Usage{
783		InputTokens:         int64(usage.ToolUsePromptTokenCount),
784		OutputTokens:        int64(usage.CandidatesTokenCount),
785		TotalTokens:         int64(usage.TotalTokenCount),
786		ReasoningTokens:     int64(usage.ThoughtsTokenCount),
787		CacheCreationTokens: int64(usage.CachedContentTokenCount),
788		CacheReadTokens:     0,
789	}
790}