google.go

  1package google
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/base64"
  7	"encoding/json"
  8	"errors"
  9	"fmt"
 10	"maps"
 11	"net/http"
 12	"strings"
 13
 14	"github.com/charmbracelet/fantasy/ai"
 15	"github.com/charmbracelet/x/exp/slice"
 16	"github.com/google/uuid"
 17	"google.golang.org/genai"
 18)
 19
 // Configuration types for the Google provider: provider is the ai.Provider
// returned by New, and options collects the values set through Option funcs.
type (
	// provider implements ai.Provider for the Google Gemini API.
	provider struct {
		options options
	}

	// options holds the settings New gathers from its Option arguments.
	options struct {
		apiKey  string            // API key handed to genai.ClientConfig
		name    string            // provider name; prefix of each model's provider ID
		headers map[string]string // extra headers collected via WithHeaders
		client  *http.Client      // optional HTTP client handed to genai.ClientConfig
	}

	// Option mutates the provider options; see WithAPIKey, WithName,
	// WithHeaders and WithHTTPClient.
	Option = func(*options)
)
 32
 33func New(opts ...Option) ai.Provider {
 34	options := options{
 35		headers: map[string]string{},
 36	}
 37	for _, o := range opts {
 38		o(&options)
 39	}
 40
 41	options.name = cmp.Or(options.name, "google")
 42
 43	return &provider{
 44		options: options,
 45	}
 46}
 47
 48func WithAPIKey(apiKey string) Option {
 49	return func(o *options) {
 50		o.apiKey = apiKey
 51	}
 52}
 53
 54func WithName(name string) Option {
 55	return func(o *options) {
 56		o.name = name
 57	}
 58}
 59
 60func WithHeaders(headers map[string]string) Option {
 61	return func(o *options) {
 62		maps.Copy(o.headers, headers)
 63	}
 64}
 65
 66func WithHTTPClient(client *http.Client) Option {
 67	return func(o *options) {
 68		o.client = client
 69	}
 70}
 71
 72type languageModel struct {
 73	provider        string
 74	modelID         string
 75	client          *genai.Client
 76	providerOptions options
 77}
 78
 79// LanguageModel implements ai.Provider.
 80func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
 81	cc := &genai.ClientConfig{
 82		APIKey:     g.options.apiKey,
 83		Backend:    genai.BackendGeminiAPI,
 84		HTTPClient: g.options.client,
 85	}
 86	client, err := genai.NewClient(context.Background(), cc)
 87	if err != nil {
 88		return nil, err
 89	}
 90	return &languageModel{
 91		modelID:         modelID,
 92		provider:        fmt.Sprintf("%s.generative-ai", g.options.name),
 93		providerOptions: g.options,
 94		client:          client,
 95	}, nil
 96}
 97
 98func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
 99	config := &genai.GenerateContentConfig{}
100	providerOptions := &providerOptions{}
101	if v, ok := call.ProviderOptions["google"]; ok {
102		err := ai.ParseOptions(v, providerOptions)
103		if err != nil {
104			return nil, nil, nil, err
105		}
106	}
107
108	systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
109
110	if providerOptions.ThinkingConfig != nil &&
111		providerOptions.ThinkingConfig.IncludeThoughts != nil &&
112		*providerOptions.ThinkingConfig.IncludeThoughts &&
113		strings.HasPrefix(a.provider, "google.vertex.") {
114		warnings = append(warnings, ai.CallWarning{
115			Type: ai.CallWarningTypeOther,
116			Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
117				"and might not be supported or could behave unexpectedly with the current Google provider " +
118				fmt.Sprintf("(%s)", a.provider),
119		})
120	}
121
122	isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")
123
124	if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
125		if len(content) > 0 && content[0].Role == genai.RoleUser {
126			systemParts := []string{}
127			for _, sp := range systemInstructions.Parts {
128				systemParts = append(systemParts, sp.Text)
129			}
130			systemMsg := strings.Join(systemParts, "\n")
131			content[0].Parts = append([]*genai.Part{
132				{
133					Text: systemMsg + "\n\n",
134				},
135			}, content[0].Parts...)
136			systemInstructions = nil
137		}
138	}
139
140	config.SystemInstruction = systemInstructions
141
142	if call.MaxOutputTokens != nil {
143		config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
144	}
145
146	if call.Temperature != nil {
147		tmp := float32(*call.Temperature)
148		config.Temperature = &tmp
149	}
150	if call.TopK != nil {
151		tmp := float32(*call.TopK)
152		config.TopK = &tmp
153	}
154	if call.TopP != nil {
155		tmp := float32(*call.TopP)
156		config.TopP = &tmp
157	}
158	if call.FrequencyPenalty != nil {
159		tmp := float32(*call.FrequencyPenalty)
160		config.FrequencyPenalty = &tmp
161	}
162	if call.PresencePenalty != nil {
163		tmp := float32(*call.PresencePenalty)
164		config.PresencePenalty = &tmp
165	}
166
167	if providerOptions.ThinkingConfig != nil {
168		config.ThinkingConfig = &genai.ThinkingConfig{}
169		if providerOptions.ThinkingConfig.IncludeThoughts != nil {
170			config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
171		}
172		if providerOptions.ThinkingConfig.ThinkingBudget != nil {
173			tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
174			config.ThinkingConfig.ThinkingBudget = &tmp
175		}
176	}
177	for _, safetySetting := range providerOptions.SafetySettings {
178		config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
179			Category:  genai.HarmCategory(safetySetting.Category),
180			Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
181		})
182	}
183	if providerOptions.CachedContent != "" {
184		config.CachedContent = providerOptions.CachedContent
185	}
186
187	if len(call.Tools) > 0 {
188		tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
189		config.ToolConfig = toolChoice
190		config.Tools = append(config.Tools, &genai.Tool{
191			FunctionDeclarations: tools,
192		})
193		warnings = append(warnings, toolWarnings...)
194	}
195
196	return config, content, warnings, nil
197}
198
199func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) { //nolint: unparam
200	var systemInstructions *genai.Content
201	var content []*genai.Content
202	var warnings []ai.CallWarning
203
204	finishedSystemBlock := false
205	for _, msg := range prompt {
206		switch msg.Role {
207		case ai.MessageRoleSystem:
208			if finishedSystemBlock {
209				// skip multiple system messages that are separated by user/assistant messages
210				// TODO: see if we need to send error here?
211				continue
212			}
213			finishedSystemBlock = true
214
215			var systemMessages []string
216			for _, part := range msg.Content {
217				text, ok := ai.AsMessagePart[ai.TextPart](part)
218				if !ok || text.Text == "" {
219					continue
220				}
221				systemMessages = append(systemMessages, text.Text)
222			}
223			if len(systemMessages) > 0 {
224				systemInstructions = &genai.Content{
225					Parts: []*genai.Part{
226						{
227							Text: strings.Join(systemMessages, "\n"),
228						},
229					},
230				}
231			}
232		case ai.MessageRoleUser:
233			var parts []*genai.Part
234			for _, part := range msg.Content {
235				switch part.GetType() {
236				case ai.ContentTypeText:
237					text, ok := ai.AsMessagePart[ai.TextPart](part)
238					if !ok || text.Text == "" {
239						continue
240					}
241					parts = append(parts, &genai.Part{
242						Text: text.Text,
243					})
244				case ai.ContentTypeFile:
245					file, ok := ai.AsMessagePart[ai.FilePart](part)
246					if !ok {
247						continue
248					}
249					var encoded []byte
250					base64.StdEncoding.Encode(encoded, file.Data)
251					parts = append(parts, &genai.Part{
252						InlineData: &genai.Blob{
253							Data:     encoded,
254							MIMEType: file.MediaType,
255						},
256					})
257				}
258			}
259			if len(parts) > 0 {
260				content = append(content, &genai.Content{
261					Role:  genai.RoleUser,
262					Parts: parts,
263				})
264			}
265		case ai.MessageRoleAssistant:
266			var parts []*genai.Part
267			for _, part := range msg.Content {
268				switch part.GetType() {
269				case ai.ContentTypeText:
270					text, ok := ai.AsMessagePart[ai.TextPart](part)
271					if !ok || text.Text == "" {
272						continue
273					}
274					parts = append(parts, &genai.Part{
275						Text: text.Text,
276					})
277				case ai.ContentTypeToolCall:
278					toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
279					if !ok {
280						continue
281					}
282
283					var result map[string]any
284					err := json.Unmarshal([]byte(toolCall.Input), &result)
285					if err != nil {
286						continue
287					}
288					parts = append(parts, &genai.Part{
289						FunctionCall: &genai.FunctionCall{
290							ID:   toolCall.ToolCallID,
291							Name: toolCall.ToolName,
292							Args: result,
293						},
294					})
295				}
296			}
297			if len(parts) > 0 {
298				content = append(content, &genai.Content{
299					Role:  genai.RoleModel,
300					Parts: parts,
301				})
302			}
303		case ai.MessageRoleTool:
304			var parts []*genai.Part
305			for _, part := range msg.Content {
306				switch part.GetType() {
307				case ai.ContentTypeToolResult:
308					result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
309					if !ok {
310						continue
311					}
312					var toolCall ai.ToolCallPart
313					for _, m := range prompt {
314						if m.Role == ai.MessageRoleAssistant {
315							for _, content := range m.Content {
316								tc, ok := ai.AsMessagePart[ai.ToolCallPart](content)
317								if !ok {
318									continue
319								}
320								if tc.ToolCallID == result.ToolCallID {
321									toolCall = tc
322									break
323								}
324							}
325						}
326					}
327					switch result.Output.GetType() {
328					case ai.ToolResultContentTypeText:
329						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
330						if !ok {
331							continue
332						}
333						response := map[string]any{"result": content.Text}
334						parts = append(parts, &genai.Part{
335							FunctionResponse: &genai.FunctionResponse{
336								ID:       result.ToolCallID,
337								Response: response,
338								Name:     toolCall.ToolName,
339							},
340						})
341
342					case ai.ToolResultContentTypeError:
343						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
344						if !ok {
345							continue
346						}
347						response := map[string]any{"result": content.Error.Error()}
348						parts = append(parts, &genai.Part{
349							FunctionResponse: &genai.FunctionResponse{
350								ID:       result.ToolCallID,
351								Response: response,
352								Name:     toolCall.ToolName,
353							},
354						})
355					}
356				}
357			}
358			if len(parts) > 0 {
359				content = append(content, &genai.Content{
360					Role:  genai.RoleUser,
361					Parts: parts,
362				})
363			}
364		}
365	}
366	return systemInstructions, content, warnings
367}
368
369// Generate implements ai.LanguageModel.
370func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
371	config, contents, warnings, err := g.prepareParams(call)
372	if err != nil {
373		return nil, err
374	}
375
376	lastMessage, history, ok := slice.Pop(contents)
377	if !ok {
378		return nil, errors.New("no messages to send")
379	}
380
381	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
382	if err != nil {
383		return nil, err
384	}
385
386	response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
387	if err != nil {
388		return nil, err
389	}
390
391	return mapResponse(response, warnings)
392}
393
394// Model implements ai.LanguageModel.
395func (g *languageModel) Model() string {
396	return g.modelID
397}
398
399// Provider implements ai.LanguageModel.
400func (g *languageModel) Provider() string {
401	return g.provider
402}
403
404// Stream implements ai.LanguageModel.
405func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
406	config, contents, warnings, err := g.prepareParams(call)
407	if err != nil {
408		return nil, err
409	}
410
411	lastMessage, history, ok := slice.Pop(contents)
412	if !ok {
413		return nil, errors.New("no messages to send")
414	}
415
416	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
417	if err != nil {
418		return nil, err
419	}
420
421	return func(yield func(ai.StreamPart) bool) {
422		if len(warnings) > 0 {
423			if !yield(ai.StreamPart{
424				Type:     ai.StreamPartTypeWarnings,
425				Warnings: warnings,
426			}) {
427				return
428			}
429		}
430
431		var currentContent string
432		var toolCalls []ai.ToolCallContent
433		var isActiveText bool
434		var isActiveReasoning bool
435		var blockCounter int
436		var currentTextBlockID string
437		var currentReasoningBlockID string
438		var usage ai.Usage
439		var lastFinishReason ai.FinishReason
440
441		for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
442			if err != nil {
443				yield(ai.StreamPart{
444					Type:  ai.StreamPartTypeError,
445					Error: err,
446				})
447				return
448			}
449
450			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
451				for _, part := range resp.Candidates[0].Content.Parts {
452					switch {
453					case part.Text != "":
454						delta := part.Text
455						if delta != "" {
456							// Check if this is a reasoning/thought part
457							if part.Thought {
458								// End any active text block before starting reasoning
459								if isActiveText {
460									isActiveText = false
461									if !yield(ai.StreamPart{
462										Type: ai.StreamPartTypeTextEnd,
463										ID:   currentTextBlockID,
464									}) {
465										return
466									}
467								}
468
469								// Start new reasoning block if not already active
470								if !isActiveReasoning {
471									isActiveReasoning = true
472									currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
473									blockCounter++
474									if !yield(ai.StreamPart{
475										Type: ai.StreamPartTypeReasoningStart,
476										ID:   currentReasoningBlockID,
477									}) {
478										return
479									}
480								}
481
482								if !yield(ai.StreamPart{
483									Type:  ai.StreamPartTypeReasoningDelta,
484									ID:    currentReasoningBlockID,
485									Delta: delta,
486								}) {
487									return
488								}
489							} else {
490								// Regular text part
491								// End any active reasoning block before starting text
492								if isActiveReasoning {
493									isActiveReasoning = false
494									if !yield(ai.StreamPart{
495										Type: ai.StreamPartTypeReasoningEnd,
496										ID:   currentReasoningBlockID,
497									}) {
498										return
499									}
500								}
501
502								// Start new text block if not already active
503								if !isActiveText {
504									isActiveText = true
505									currentTextBlockID = fmt.Sprintf("%d", blockCounter)
506									blockCounter++
507									if !yield(ai.StreamPart{
508										Type: ai.StreamPartTypeTextStart,
509										ID:   currentTextBlockID,
510									}) {
511										return
512									}
513								}
514
515								if !yield(ai.StreamPart{
516									Type:  ai.StreamPartTypeTextDelta,
517									ID:    currentTextBlockID,
518									Delta: delta,
519								}) {
520									return
521								}
522								currentContent += delta
523							}
524						}
525					case part.FunctionCall != nil:
526						// End any active text or reasoning blocks
527						if isActiveText {
528							isActiveText = false
529							if !yield(ai.StreamPart{
530								Type: ai.StreamPartTypeTextEnd,
531								ID:   currentTextBlockID,
532							}) {
533								return
534							}
535						}
536						if isActiveReasoning {
537							isActiveReasoning = false
538							if !yield(ai.StreamPart{
539								Type: ai.StreamPartTypeReasoningEnd,
540								ID:   currentReasoningBlockID,
541							}) {
542								return
543							}
544						}
545
546						toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
547
548						args, err := json.Marshal(part.FunctionCall.Args)
549						if err != nil {
550							yield(ai.StreamPart{
551								Type:  ai.StreamPartTypeError,
552								Error: err,
553							})
554							return
555						}
556
557						if !yield(ai.StreamPart{
558							Type:         ai.StreamPartTypeToolInputStart,
559							ID:           toolCallID,
560							ToolCallName: part.FunctionCall.Name,
561						}) {
562							return
563						}
564
565						if !yield(ai.StreamPart{
566							Type:  ai.StreamPartTypeToolInputDelta,
567							ID:    toolCallID,
568							Delta: string(args),
569						}) {
570							return
571						}
572
573						if !yield(ai.StreamPart{
574							Type: ai.StreamPartTypeToolInputEnd,
575							ID:   toolCallID,
576						}) {
577							return
578						}
579
580						if !yield(ai.StreamPart{
581							Type:             ai.StreamPartTypeToolCall,
582							ID:               toolCallID,
583							ToolCallName:     part.FunctionCall.Name,
584							ToolCallInput:    string(args),
585							ProviderExecuted: false,
586						}) {
587							return
588						}
589
590						toolCalls = append(toolCalls, ai.ToolCallContent{
591							ToolCallID:       toolCallID,
592							ToolName:         part.FunctionCall.Name,
593							Input:            string(args),
594							ProviderExecuted: false,
595						})
596					}
597				}
598			}
599
600			if resp.UsageMetadata != nil {
601				usage = mapUsage(resp.UsageMetadata)
602			}
603
604			if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
605				lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
606			}
607		}
608
609		// Close any open blocks before finishing
610		if isActiveText {
611			if !yield(ai.StreamPart{
612				Type: ai.StreamPartTypeTextEnd,
613				ID:   currentTextBlockID,
614			}) {
615				return
616			}
617		}
618		if isActiveReasoning {
619			if !yield(ai.StreamPart{
620				Type: ai.StreamPartTypeReasoningEnd,
621				ID:   currentReasoningBlockID,
622			}) {
623				return
624			}
625		}
626
627		finishReason := lastFinishReason
628		if len(toolCalls) > 0 {
629			finishReason = ai.FinishReasonToolCalls
630		} else if finishReason == "" {
631			finishReason = ai.FinishReasonStop
632		}
633
634		yield(ai.StreamPart{
635			Type:         ai.StreamPartTypeFinish,
636			Usage:        usage,
637			FinishReason: finishReason,
638		})
639	}, nil
640}
641
642func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
643	for _, tool := range tools {
644		if tool.GetType() == ai.ToolTypeFunction {
645			ft, ok := tool.(ai.FunctionTool)
646			if !ok {
647				continue
648			}
649
650			required := []string{}
651			var properties map[string]any
652			if props, ok := ft.InputSchema["properties"]; ok {
653				properties, _ = props.(map[string]any)
654			}
655			if req, ok := ft.InputSchema["required"]; ok {
656				if reqArr, ok := req.([]string); ok {
657					required = reqArr
658				}
659			}
660			declaration := &genai.FunctionDeclaration{
661				Name:        ft.Name,
662				Description: ft.Description,
663				Parameters: &genai.Schema{
664					Type:       genai.TypeObject,
665					Properties: convertSchemaProperties(properties),
666					Required:   required,
667				},
668			}
669			googleTools = append(googleTools, declaration)
670			continue
671		}
672		// TODO: handle provider tool calls
673		warnings = append(warnings, ai.CallWarning{
674			Type:    ai.CallWarningTypeUnsupportedTool,
675			Tool:    tool,
676			Message: "tool is not supported",
677		})
678	}
679	if toolChoice == nil {
680		return //nolint: nakedret
681	}
682	switch *toolChoice {
683	case ai.ToolChoiceAuto:
684		googleToolChoice = &genai.ToolConfig{
685			FunctionCallingConfig: &genai.FunctionCallingConfig{
686				Mode: genai.FunctionCallingConfigModeAuto,
687			},
688		}
689	case ai.ToolChoiceRequired:
690		googleToolChoice = &genai.ToolConfig{
691			FunctionCallingConfig: &genai.FunctionCallingConfig{
692				Mode: genai.FunctionCallingConfigModeAny,
693			},
694		}
695	case ai.ToolChoiceNone:
696		googleToolChoice = &genai.ToolConfig{
697			FunctionCallingConfig: &genai.FunctionCallingConfig{
698				Mode: genai.FunctionCallingConfigModeNone,
699			},
700		}
701	default:
702		googleToolChoice = &genai.ToolConfig{
703			FunctionCallingConfig: &genai.FunctionCallingConfig{
704				Mode: genai.FunctionCallingConfigModeAny,
705				AllowedFunctionNames: []string{
706					string(*toolChoice),
707				},
708			},
709		}
710	}
711	return //nolint: nakedret
712}
713
714func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
715	properties := make(map[string]*genai.Schema)
716
717	for name, param := range parameters {
718		properties[name] = convertToSchema(param)
719	}
720
721	return properties
722}
723
724func convertToSchema(param any) *genai.Schema {
725	schema := &genai.Schema{Type: genai.TypeString}
726
727	paramMap, ok := param.(map[string]any)
728	if !ok {
729		return schema
730	}
731
732	if desc, ok := paramMap["description"].(string); ok {
733		schema.Description = desc
734	}
735
736	typeVal, hasType := paramMap["type"]
737	if !hasType {
738		return schema
739	}
740
741	typeStr, ok := typeVal.(string)
742	if !ok {
743		return schema
744	}
745
746	schema.Type = mapJSONTypeToGoogle(typeStr)
747
748	switch typeStr {
749	case "array":
750		schema.Items = processArrayItems(paramMap)
751	case "object":
752		if props, ok := paramMap["properties"].(map[string]any); ok {
753			schema.Properties = convertSchemaProperties(props)
754		}
755	}
756
757	return schema
758}
759
760func processArrayItems(paramMap map[string]any) *genai.Schema {
761	items, ok := paramMap["items"].(map[string]any)
762	if !ok {
763		return nil
764	}
765
766	return convertToSchema(items)
767}
768
769func mapJSONTypeToGoogle(jsonType string) genai.Type {
770	switch jsonType {
771	case "string":
772		return genai.TypeString
773	case "number":
774		return genai.TypeNumber
775	case "integer":
776		return genai.TypeInteger
777	case "boolean":
778		return genai.TypeBoolean
779	case "array":
780		return genai.TypeArray
781	case "object":
782		return genai.TypeObject
783	default:
784		return genai.TypeString // Default to string for unknown types
785	}
786}
787
788func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
789	if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
790		return nil, errors.New("no response from model")
791	}
792
793	var (
794		content      []ai.Content
795		finishReason ai.FinishReason
796		hasToolCalls bool
797		candidate    = response.Candidates[0]
798	)
799
800	for _, part := range candidate.Content.Parts {
801		switch {
802		case part.Text != "":
803			if part.Thought {
804				content = append(content, ai.ReasoningContent{Text: part.Text})
805			} else {
806				content = append(content, ai.TextContent{Text: part.Text})
807			}
808		case part.FunctionCall != nil:
809			input, err := json.Marshal(part.FunctionCall.Args)
810			if err != nil {
811				return nil, err
812			}
813			toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
814			content = append(content, ai.ToolCallContent{
815				ToolCallID:       toolCallID,
816				ToolName:         part.FunctionCall.Name,
817				Input:            string(input),
818				ProviderExecuted: false,
819			})
820			hasToolCalls = true
821		default:
822			// Silently skip unknown part types instead of erroring
823			// This allows for forward compatibility with new part types
824		}
825	}
826
827	if hasToolCalls {
828		finishReason = ai.FinishReasonToolCalls
829	} else {
830		finishReason = mapFinishReason(candidate.FinishReason)
831	}
832
833	return &ai.Response{
834		Content:      content,
835		Usage:        mapUsage(response.UsageMetadata),
836		FinishReason: finishReason,
837		Warnings:     warnings,
838	}, nil
839}
840
841func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
842	switch reason {
843	case genai.FinishReasonStop:
844		return ai.FinishReasonStop
845	case genai.FinishReasonMaxTokens:
846		return ai.FinishReasonLength
847	case genai.FinishReasonSafety,
848		genai.FinishReasonBlocklist,
849		genai.FinishReasonProhibitedContent,
850		genai.FinishReasonSPII,
851		genai.FinishReasonImageSafety:
852		return ai.FinishReasonContentFilter
853	case genai.FinishReasonRecitation,
854		genai.FinishReasonLanguage,
855		genai.FinishReasonMalformedFunctionCall:
856		return ai.FinishReasonError
857	case genai.FinishReasonOther:
858		return ai.FinishReasonOther
859	default:
860		return ai.FinishReasonUnknown
861	}
862}
863
864func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
865	return ai.Usage{
866		InputTokens:         int64(usage.ToolUsePromptTokenCount),
867		OutputTokens:        int64(usage.CandidatesTokenCount),
868		TotalTokens:         int64(usage.TotalTokenCount),
869		ReasoningTokens:     int64(usage.ThoughtsTokenCount),
870		CacheCreationTokens: int64(usage.CachedContentTokenCount),
871		CacheReadTokens:     0,
872	}
873}