package google

import (
	"cmp"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"maps"
	"net/http"
	"strings"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/x/exp/slice"
	"github.com/google/uuid"
	"google.golang.org/genai"
)

type provider struct {
	options options
}

type options struct {
	apiKey  string
	name    string
	headers map[string]string
	client  *http.Client
}

type Option = func(*options)

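// New constructs a Gemini provider backed by the Google generative AI API.
// The provider name defaults to "google" unless overridden with WithName.
// Typical usage (any Gemini model ID works in place of the example):
//
//	provider := google.New(google.WithAPIKey(apiKey))
//	model, err := provider.LanguageModel("gemini-2.0-flash")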
func New(opts ...Option) ai.Provider {
	options := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	options.name = cmp.Or(options.name, "google")

	return &provider{
		options: options,
	}
}

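// WithAPIKey sets the API key used to authenticate with the Gemini API.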
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

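// WithName overrides the provider name reported in model metadata
// (defaults to "google").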
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

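// WithHeaders merges the given headers into the provider's header set.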
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

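// WithHTTPClient sets a custom HTTP client for requests to the Gemini API.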
func WithHTTPClient(client *http.Client) Option {
	return func(o *options) {
		o.client = client
	}
}

func (*provider) Name() string {
	return Name
}

func (a *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
	var options ProviderOptions
	if err := ai.ParseOptions(data, &options); err != nil {
		return nil, err
	}
	return &options, nil
}

type languageModel struct {
	provider        string
	modelID         string
	client          *genai.Client
	providerOptions options
}

// LanguageModel implements ai.Provider.
func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	cc := &genai.ClientConfig{
		APIKey:     g.options.apiKey,
		Backend:    genai.BackendGeminiAPI,
		HTTPClient: g.options.client,
	}
	client, err := genai.NewClient(context.Background(), cc)
	if err != nil {
		return nil, err
	}
	return &languageModel{
		modelID:         modelID,
		provider:        fmt.Sprintf("%s.generative-ai", g.options.name),
		providerOptions: g.options,
		client:          client,
	}, nil
}

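// prepareParams converts an ai.Call into the genai request configuration,
// returning the generation config, the converted message contents, and any
// warnings for options the backend does not support.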
func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
	config := &genai.GenerateContentConfig{}

	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, nil, nil, ai.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
		}
	}

	systemInstructions, content, warnings := toGooglePrompt(call.Prompt)

	if providerOptions.ThinkingConfig != nil &&
		providerOptions.ThinkingConfig.IncludeThoughts != nil &&
		*providerOptions.ThinkingConfig.IncludeThoughts &&
		!strings.HasPrefix(a.provider, "google.vertex.") {
		warnings = append(warnings, ai.CallWarning{
			Type: ai.CallWarningTypeOther,
			Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
				"and might not be supported or could behave unexpectedly with the current Google provider " +
				fmt.Sprintf("(%s)", a.provider),
		})
	}

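	// Gemma models do not accept a separate system instruction, so fold any
	// system text into the start of the first user message instead.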
	isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")

	if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
		if len(content) > 0 && content[0].Role == genai.RoleUser {
			systemParts := []string{}
			for _, sp := range systemInstructions.Parts {
				systemParts = append(systemParts, sp.Text)
			}
			systemMsg := strings.Join(systemParts, "\n")
			content[0].Parts = append([]*genai.Part{
				{
					Text: systemMsg + "\n\n",
				},
			}, content[0].Parts...)
			systemInstructions = nil
		}
	}

	config.SystemInstruction = systemInstructions

	if call.MaxOutputTokens != nil {
		config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
	}

	if call.Temperature != nil {
		tmp := float32(*call.Temperature)
		config.Temperature = &tmp
	}
	if call.TopK != nil {
		tmp := float32(*call.TopK)
		config.TopK = &tmp
	}
	if call.TopP != nil {
		tmp := float32(*call.TopP)
		config.TopP = &tmp
	}
	if call.FrequencyPenalty != nil {
		tmp := float32(*call.FrequencyPenalty)
		config.FrequencyPenalty = &tmp
	}
	if call.PresencePenalty != nil {
		tmp := float32(*call.PresencePenalty)
		config.PresencePenalty = &tmp
	}

	if providerOptions.ThinkingConfig != nil {
		config.ThinkingConfig = &genai.ThinkingConfig{}
		if providerOptions.ThinkingConfig.IncludeThoughts != nil {
			config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
		}
		if providerOptions.ThinkingConfig.ThinkingBudget != nil {
			tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
			config.ThinkingConfig.ThinkingBudget = &tmp
		}
	}
	for _, safetySetting := range providerOptions.SafetySettings {
		config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
			Category:  genai.HarmCategory(safetySetting.Category),
			Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
		})
	}
	if providerOptions.CachedContent != "" {
		config.CachedContent = providerOptions.CachedContent
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
		config.ToolConfig = toolChoice
		config.Tools = append(config.Tools, &genai.Tool{
			FunctionDeclarations: tools,
		})
		warnings = append(warnings, toolWarnings...)
	}

	return config, content, warnings, nil
}

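// toGooglePrompt splits an ai.Prompt into a system instruction and a list of
// genai contents: leading system messages become the SystemInstruction, user
// and tool messages map to user-role contents, and assistant messages map to
// model-role contents.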
func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) { //nolint: unparam
	var systemInstructions *genai.Content
	var content []*genai.Content
	var warnings []ai.CallWarning

	finishedSystemBlock := false
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			if finishedSystemBlock {
				// Skip system messages that appear after user/assistant
				// messages.
				// TODO: see if we need to send an error here?
				continue
			}
			finishedSystemBlock = true

			var systemMessages []string
			for _, part := range msg.Content {
				text, ok := ai.AsMessagePart[ai.TextPart](part)
				if !ok || text.Text == "" {
					continue
				}
				systemMessages = append(systemMessages, text.Text)
			}
			if len(systemMessages) > 0 {
				systemInstructions = &genai.Content{
					Parts: []*genai.Part{
						{
							Text: strings.Join(systemMessages, "\n"),
						},
					},
				}
			}
		case ai.MessageRoleUser:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeFile:
					file, ok := ai.AsMessagePart[ai.FilePart](part)
					if !ok {
						continue
					}
					// Pass the raw bytes through: the genai SDK
					// base64-encodes inline data when serializing the
					// request, so encoding here would double-encode.
					parts = append(parts, &genai.Part{
						InlineData: &genai.Blob{
							Data:     file.Data,
							MIMEType: file.MediaType,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		case ai.MessageRoleAssistant:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeToolCall:
					toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
					if !ok {
						continue
					}

					var args map[string]any
					if err := json.Unmarshal([]byte(toolCall.Input), &args); err != nil {
						continue
					}
					parts = append(parts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							ID:   toolCall.ToolCallID,
							Name: toolCall.ToolName,
							Args: args,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleModel,
					Parts: parts,
				})
			}
		case ai.MessageRoleTool:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeToolResult:
					result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
					if !ok {
						continue
					}
					// Recover the tool name from the assistant message that
					// issued the matching tool call.
					var toolCall ai.ToolCallPart
					for _, m := range prompt {
						if m.Role != ai.MessageRoleAssistant {
							continue
						}
						for _, c := range m.Content {
							tc, ok := ai.AsMessagePart[ai.ToolCallPart](c)
							if !ok {
								continue
							}
							if tc.ToolCallID == result.ToolCallID {
								toolCall = tc
								break
							}
						}
					}
					switch result.Output.GetType() {
					case ai.ToolResultContentTypeText:
						output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": output.Text}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})
					case ai.ToolResultContentTypeError:
						output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": output.Error.Error()}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})
					}
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		}
	}
	return systemInstructions, content, warnings
}

// Generate implements ai.LanguageModel.
func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}

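	// The genai chat API takes prior turns as history and sends the final
	// message separately, so split the converted contents accordingly.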
	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}

	response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
	if err != nil {
		return nil, err
	}

	return mapResponse(response, warnings)
}

// Model implements ai.LanguageModel.
func (g *languageModel) Model() string {
	return g.modelID
}

// Provider implements ai.LanguageModel.
func (g *languageModel) Provider() string {
	return g.provider
}

// Stream implements ai.LanguageModel.
func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}

	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}

	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

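		// Gemini interleaves thought ("reasoning") parts with regular text
		// in a single candidate stream, so track which block type is open
		// and emit matching start/end events around the deltas.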
		var toolCalls []ai.ToolCallContent
		var isActiveText bool
		var isActiveReasoning bool
		var blockCounter int
		var currentTextBlockID string
		var currentReasoningBlockID string
		var usage ai.Usage
		var lastFinishReason ai.FinishReason

		for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
			if err != nil {
				yield(ai.StreamPart{
					Type:  ai.StreamPartTypeError,
					Error: err,
				})
				return
			}

			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
				for _, part := range resp.Candidates[0].Content.Parts {
					switch {
					case part.Text != "":
						delta := part.Text
						if part.Thought {
							// End any active text block before starting reasoning.
							if isActiveText {
								isActiveText = false
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeTextEnd,
									ID:   currentTextBlockID,
								}) {
									return
								}
							}

							// Start a new reasoning block if one is not already active.
							if !isActiveReasoning {
								isActiveReasoning = true
								currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
								blockCounter++
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeReasoningStart,
									ID:   currentReasoningBlockID,
								}) {
									return
								}
							}

							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeReasoningDelta,
								ID:    currentReasoningBlockID,
								Delta: delta,
							}) {
								return
							}
						} else {
							// Regular text part: end any active reasoning block
							// before starting text.
							if isActiveReasoning {
								isActiveReasoning = false
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeReasoningEnd,
									ID:   currentReasoningBlockID,
								}) {
									return
								}
							}

							// Start a new text block if one is not already active.
							if !isActiveText {
								isActiveText = true
								currentTextBlockID = fmt.Sprintf("%d", blockCounter)
								blockCounter++
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeTextStart,
									ID:   currentTextBlockID,
								}) {
									return
								}
							}

							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeTextDelta,
								ID:    currentTextBlockID,
								Delta: delta,
							}) {
								return
							}
						}
					case part.FunctionCall != nil:
						// End any active text or reasoning blocks.
						if isActiveText {
							isActiveText = false
							if !yield(ai.StreamPart{
								Type: ai.StreamPartTypeTextEnd,
								ID:   currentTextBlockID,
							}) {
								return
							}
						}
						if isActiveReasoning {
							isActiveReasoning = false
							if !yield(ai.StreamPart{
								Type: ai.StreamPartTypeReasoningEnd,
								ID:   currentReasoningBlockID,
							}) {
								return
							}
						}

						toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())

						args, err := json.Marshal(part.FunctionCall.Args)
						if err != nil {
							yield(ai.StreamPart{
								Type:  ai.StreamPartTypeError,
								Error: err,
							})
							return
						}

						if !yield(ai.StreamPart{
							Type:         ai.StreamPartTypeToolInputStart,
							ID:           toolCallID,
							ToolCallName: part.FunctionCall.Name,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:  ai.StreamPartTypeToolInputDelta,
							ID:    toolCallID,
							Delta: string(args),
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeToolInputEnd,
							ID:   toolCallID,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:             ai.StreamPartTypeToolCall,
							ID:               toolCallID,
							ToolCallName:     part.FunctionCall.Name,
							ToolCallInput:    string(args),
							ProviderExecuted: false,
						}) {
							return
						}

						toolCalls = append(toolCalls, ai.ToolCallContent{
							ToolCallID:       toolCallID,
							ToolName:         part.FunctionCall.Name,
							Input:            string(args),
							ProviderExecuted: false,
						})
					}
				}
			}

			if resp.UsageMetadata != nil {
				usage = mapUsage(resp.UsageMetadata)
			}

			if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
				lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
			}
		}

		// Close any open blocks before finishing.
		if isActiveText {
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID:   currentTextBlockID,
			}) {
				return
			}
		}
		if isActiveReasoning {
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeReasoningEnd,
				ID:   currentReasoningBlockID,
			}) {
				return
			}
		}

		finishReason := lastFinishReason
		if len(toolCalls) > 0 {
			finishReason = ai.FinishReasonToolCalls
		} else if finishReason == "" {
			finishReason = ai.FinishReasonStop
		}

		yield(ai.StreamPart{
			Type:         ai.StreamPartTypeFinish,
			Usage:        usage,
			FinishReason: finishReason,
		})
	}, nil
}

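// toGoogleTools converts function tools to genai declarations and maps the
// tool choice onto genai's function-calling mode; a specific tool name maps
// to mode ANY restricted to that single function.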
func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}

			required := []string{}
			var properties map[string]any
			if props, ok := ft.InputSchema["properties"]; ok {
				properties, _ = props.(map[string]any)
			}
			// The schema may be built in Go ([]string) or decoded from JSON
			// ([]any), so accept both shapes for "required".
			if req, ok := ft.InputSchema["required"]; ok {
				switch reqTyped := req.(type) {
				case []string:
					required = reqTyped
				case []any:
					for _, r := range reqTyped {
						if s, ok := r.(string); ok {
							required = append(required, s)
						}
					}
				}
			}
			declaration := &genai.FunctionDeclaration{
				Name:        ft.Name,
				Description: ft.Description,
				Parameters: &genai.Schema{
					Type:       genai.TypeObject,
					Properties: convertSchemaProperties(properties),
					Required:   required,
				},
			}
			googleTools = append(googleTools, declaration)
			continue
		}
		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return //nolint: nakedret
	}
	switch *toolChoice {
	case ai.ToolChoiceAuto:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAuto,
			},
		}
	case ai.ToolChoiceRequired:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
			},
		}
	case ai.ToolChoiceNone:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeNone,
			},
		}
	default:
		// Any other value names a specific tool the model must call.
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
				AllowedFunctionNames: []string{
					string(*toolChoice),
				},
			},
		}
	}
	return //nolint: nakedret
}

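// convertSchemaProperties maps a JSON Schema "properties" object onto
// genai.Schema values, recursing into nested objects and array items.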
func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

func convertToSchema(param any) *genai.Schema {
	// Default to a string schema when the parameter has no usable type info.
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGoogle(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

func mapJSONTypeToGoogle(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types.
	}
}

func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
	if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
		return nil, errors.New("no response from model")
	}

	var (
		content      []ai.Content
		finishReason ai.FinishReason
		hasToolCalls bool
		candidate    = response.Candidates[0]
	)

	for _, part := range candidate.Content.Parts {
		switch {
		case part.Text != "":
			if part.Thought {
				content = append(content, ai.ReasoningContent{Text: part.Text})
			} else {
				content = append(content, ai.TextContent{Text: part.Text})
			}
		case part.FunctionCall != nil:
			input, err := json.Marshal(part.FunctionCall.Args)
			if err != nil {
				return nil, err
			}
			toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
			content = append(content, ai.ToolCallContent{
				ToolCallID:       toolCallID,
				ToolName:         part.FunctionCall.Name,
				Input:            string(input),
				ProviderExecuted: false,
			})
			hasToolCalls = true
		default:
			// Silently skip unknown part types instead of erroring; this
			// allows forward compatibility with new part types.
		}
	}

	if hasToolCalls {
		finishReason = ai.FinishReasonToolCalls
	} else {
		finishReason = mapFinishReason(candidate.FinishReason)
	}

	return &ai.Response{
		Content:      content,
		Usage:        mapUsage(response.UsageMetadata),
		FinishReason: finishReason,
		Warnings:     warnings,
	}, nil
}

func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return ai.FinishReasonStop
	case genai.FinishReasonMaxTokens:
		return ai.FinishReasonLength
	case genai.FinishReasonSafety,
		genai.FinishReasonBlocklist,
		genai.FinishReasonProhibitedContent,
		genai.FinishReasonSPII,
		genai.FinishReasonImageSafety:
		return ai.FinishReasonContentFilter
	case genai.FinishReasonRecitation,
		genai.FinishReasonLanguage,
		genai.FinishReasonMalformedFunctionCall:
		return ai.FinishReasonError
	case genai.FinishReasonOther:
		return ai.FinishReasonOther
	default:
		return ai.FinishReasonUnknown
	}
}

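// mapUsage converts genai usage metadata to ai.Usage. CachedContentTokenCount
// counts prompt tokens served from the context cache, so it maps to cache
// reads; the Gemini usage metadata reports no separate cache-creation figure.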
func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
	if usage == nil {
		return ai.Usage{}
	}
	return ai.Usage{
		InputTokens:         int64(usage.PromptTokenCount),
		OutputTokens:        int64(usage.CandidatesTokenCount),
		TotalTokens:         int64(usage.TotalTokenCount),
		ReasoningTokens:     int64(usage.ThoughtsTokenCount),
		CacheCreationTokens: 0,
		CacheReadTokens:     int64(usage.CachedContentTokenCount),
	}
}