// Package google implements an ai.Provider for Google's Gemini models on
// top of the google.golang.org/genai SDK.
package google

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"maps"
	"net/http"
	"strings"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/x/exp/slice"
	"github.com/google/uuid"
	"google.golang.org/genai"
)

// provider wires the configured options into an ai.Provider.
type provider struct {
	options options
}

// options holds the provider configuration collected from Option funcs.
type options struct {
	apiKey  string
	name    string
	headers map[string]string
	client  *http.Client
}

// Option configures the provider created by New.
type Option = func(*options)

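// New creates a Google provider configured by the given options.
//
// A minimal usage sketch (the API key and model ID are placeholders):
//
//	p := New(WithAPIKey("YOUR_API_KEY"))
//	model, err := p.LanguageModel("gemini-2.0-flash")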
func New(opts ...Option) ai.Provider {
	options := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	options.name = cmp.Or(options.name, "google")

	return &provider{
		options: options,
	}
}

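// WithAPIKey sets the API key used to authenticate with the Gemini API.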
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

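// WithName overrides the provider name (defaults to "google").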
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

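// WithHeaders adds custom headers to requests made by the provider.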
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

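// WithHTTPClient sets the HTTP client used for API requests.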
func WithHTTPClient(client *http.Client) Option {
	return func(o *options) {
		o.client = client
	}
}

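// languageModel implements ai.LanguageModel on top of a genai.Client.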
type languageModel struct {
	provider        string
	modelID         string
	client          *genai.Client
	providerOptions options
}

// LanguageModel implements ai.Provider.
func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	cc := &genai.ClientConfig{
		APIKey:     g.options.apiKey,
		Backend:    genai.BackendGeminiAPI,
		HTTPClient: g.options.client,
	}
	// Forward any headers configured via WithHeaders to the genai client;
	// previously they were collected but never used.
	if len(g.options.headers) > 0 {
		header := http.Header{}
		for k, v := range g.options.headers {
			header.Set(k, v)
		}
		cc.HTTPOptions.Headers = header
	}
	client, err := genai.NewClient(context.Background(), cc)
	if err != nil {
		return nil, err
	}
	return &languageModel{
		modelID:         modelID,
		provider:        fmt.Sprintf("%s.generative-ai", g.options.name),
		providerOptions: g.options,
		client:          client,
	}, nil
}

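// prepareParams translates an ai.Call into the genai request config and
// content list, collecting any call warnings along the way.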
func (g *languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
	config := &genai.GenerateContentConfig{}
	providerOptions := &providerOptions{}
	if v, ok := call.ProviderOptions["google"]; ok {
		err := ai.ParseOptions(v, providerOptions)
		if err != nil {
			return nil, nil, nil, err
		}
	}

	systemInstructions, content, warnings := toGooglePrompt(call.Prompt)

	if providerOptions.ThinkingConfig != nil &&
		providerOptions.ThinkingConfig.IncludeThoughts != nil &&
		*providerOptions.ThinkingConfig.IncludeThoughts &&
		!strings.HasPrefix(g.provider, "google.vertex.") {
		warnings = append(warnings, ai.CallWarning{
			Type: ai.CallWarningTypeOther,
			Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
				"and might not be supported or could behave unexpectedly with the current Google provider " +
				fmt.Sprintf("(%s)", g.provider),
		})
	}

	isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")

	// Gemma models don't support system instructions, so fold them into the
	// first user message instead.
	if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
		if len(content) > 0 && content[0].Role == genai.RoleUser {
			systemParts := []string{}
			for _, sp := range systemInstructions.Parts {
				systemParts = append(systemParts, sp.Text)
			}
			systemMsg := strings.Join(systemParts, "\n")
			content[0].Parts = append([]*genai.Part{
				{
					Text: systemMsg + "\n\n",
				},
			}, content[0].Parts...)
			systemInstructions = nil
		}
	}

	config.SystemInstruction = systemInstructions

	if call.MaxOutputTokens != nil {
		config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
	}

	if call.Temperature != nil {
		tmp := float32(*call.Temperature)
		config.Temperature = &tmp
	}
	if call.TopK != nil {
		tmp := float32(*call.TopK)
		config.TopK = &tmp
	}
	if call.TopP != nil {
		tmp := float32(*call.TopP)
		config.TopP = &tmp
	}
	if call.FrequencyPenalty != nil {
		tmp := float32(*call.FrequencyPenalty)
		config.FrequencyPenalty = &tmp
	}
	if call.PresencePenalty != nil {
		tmp := float32(*call.PresencePenalty)
		config.PresencePenalty = &tmp
	}

	if providerOptions.ThinkingConfig != nil {
		config.ThinkingConfig = &genai.ThinkingConfig{}
		if providerOptions.ThinkingConfig.IncludeThoughts != nil {
			config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
		}
		if providerOptions.ThinkingConfig.ThinkingBudget != nil {
			tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
			config.ThinkingConfig.ThinkingBudget = &tmp
		}
	}
	for _, safetySetting := range providerOptions.SafetySettings {
		config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
			Category:  genai.HarmCategory(safetySetting.Category),
			Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
		})
	}
	if providerOptions.CachedContent != "" {
		config.CachedContent = providerOptions.CachedContent
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
		config.ToolConfig = toolChoice
		config.Tools = append(config.Tools, &genai.Tool{
			FunctionDeclarations: tools,
		})
		warnings = append(warnings, toolWarnings...)
	}

	return config, content, warnings, nil
}

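// toGooglePrompt splits an ai.Prompt into system instructions and
// conversation content in the shape the genai SDK expects.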
func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) { //nolint: unparam
	var systemInstructions *genai.Content
	var content []*genai.Content
	var warnings []ai.CallWarning

	finishedSystemBlock := false
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			if finishedSystemBlock {
				// Skip system messages that are separated from the leading
				// system block by user/assistant messages.
				// TODO: see if we need to send an error here?
				continue
			}

			var systemMessages []string
			for _, part := range msg.Content {
				text, ok := ai.AsMessagePart[ai.TextPart](part)
				if !ok || text.Text == "" {
					continue
				}
				systemMessages = append(systemMessages, text.Text)
			}
			if len(systemMessages) > 0 {
				if systemInstructions == nil {
					systemInstructions = &genai.Content{}
				}
				systemInstructions.Parts = append(systemInstructions.Parts, &genai.Part{
					Text: strings.Join(systemMessages, "\n"),
				})
			}
		case ai.MessageRoleUser:
			finishedSystemBlock = true
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeFile:
					file, ok := ai.AsMessagePart[ai.FilePart](part)
					if !ok {
						continue
					}
					// Allocate the destination before encoding: encoding into
					// a nil slice panics for non-empty input.
					encoded := make([]byte, base64.StdEncoding.EncodedLen(len(file.Data)))
					base64.StdEncoding.Encode(encoded, file.Data)
					parts = append(parts, &genai.Part{
						InlineData: &genai.Blob{
							Data:     encoded,
							MIMEType: file.MediaType,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		case ai.MessageRoleAssistant:
			finishedSystemBlock = true
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeToolCall:
					toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
					if !ok {
						continue
					}

					var result map[string]any
					err := json.Unmarshal([]byte(toolCall.Input), &result)
					if err != nil {
						continue
					}
					parts = append(parts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							ID:   toolCall.ToolCallID,
							Name: toolCall.ToolName,
							Args: result,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleModel,
					Parts: parts,
				})
			}
		case ai.MessageRoleTool:
			finishedSystemBlock = true
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeToolResult:
					result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
					if !ok {
						continue
					}
					// Find the assistant tool call this result belongs to so
					// we can recover the function name.
					var toolCall ai.ToolCallPart
				findCall:
					for _, m := range prompt {
						if m.Role != ai.MessageRoleAssistant {
							continue
						}
						for _, c := range m.Content {
							tc, ok := ai.AsMessagePart[ai.ToolCallPart](c)
							if !ok {
								continue
							}
							if tc.ToolCallID == result.ToolCallID {
								toolCall = tc
								break findCall
							}
						}
					}
					switch result.Output.GetType() {
					case ai.ToolResultContentTypeText:
						output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": output.Text}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})

					case ai.ToolResultContentTypeError:
						output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": output.Error.Error()}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})
					}
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		}
	}
	return systemInstructions, content, warnings
}

// Generate implements ai.LanguageModel.
func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}

	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}

	response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
	if err != nil {
		return nil, err
	}

	return mapResponse(response, warnings)
}

// Model implements ai.LanguageModel.
func (g *languageModel) Model() string {
	return g.modelID
}

// Provider implements ai.LanguageModel.
func (g *languageModel) Provider() string {
	return g.provider
}

// Stream implements ai.LanguageModel.
func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}

	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}

	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		var currentContent string
		var toolCalls []ai.ToolCallContent
		var isActiveText bool
		var isActiveReasoning bool
		var blockCounter int
		var currentTextBlockID string
		var currentReasoningBlockID string
		var usage ai.Usage
		var lastFinishReason ai.FinishReason

		for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
			if err != nil {
				yield(ai.StreamPart{
					Type:  ai.StreamPartTypeError,
					Error: err,
				})
				return
			}

			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
				for _, part := range resp.Candidates[0].Content.Parts {
					switch {
					case part.Text != "":
						delta := part.Text
						// Check if this is a reasoning/thought part.
						if part.Thought {
							// End any active text block before starting reasoning.
							if isActiveText {
								isActiveText = false
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeTextEnd,
									ID:   currentTextBlockID,
								}) {
									return
								}
							}

							// Start a new reasoning block if not already active.
							if !isActiveReasoning {
								isActiveReasoning = true
								currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
								blockCounter++
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeReasoningStart,
									ID:   currentReasoningBlockID,
								}) {
									return
								}
							}

							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeReasoningDelta,
								ID:    currentReasoningBlockID,
								Delta: delta,
							}) {
								return
							}
						} else {
							// Regular text part: end any active reasoning
							// block before starting text.
							if isActiveReasoning {
								isActiveReasoning = false
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeReasoningEnd,
									ID:   currentReasoningBlockID,
								}) {
									return
								}
							}

							// Start a new text block if not already active.
							if !isActiveText {
								isActiveText = true
								currentTextBlockID = fmt.Sprintf("%d", blockCounter)
								blockCounter++
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeTextStart,
									ID:   currentTextBlockID,
								}) {
									return
								}
							}

							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeTextDelta,
								ID:    currentTextBlockID,
								Delta: delta,
							}) {
								return
							}
							currentContent += delta
						}
					case part.FunctionCall != nil:
						// End any active text or reasoning blocks.
						if isActiveText {
							isActiveText = false
							if !yield(ai.StreamPart{
								Type: ai.StreamPartTypeTextEnd,
								ID:   currentTextBlockID,
							}) {
								return
							}
						}
						if isActiveReasoning {
							isActiveReasoning = false
							if !yield(ai.StreamPart{
								Type: ai.StreamPartTypeReasoningEnd,
								ID:   currentReasoningBlockID,
							}) {
								return
							}
						}

						toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())

						args, err := json.Marshal(part.FunctionCall.Args)
						if err != nil {
							yield(ai.StreamPart{
								Type:  ai.StreamPartTypeError,
								Error: err,
							})
							return
						}

						if !yield(ai.StreamPart{
							Type:         ai.StreamPartTypeToolInputStart,
							ID:           toolCallID,
							ToolCallName: part.FunctionCall.Name,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:  ai.StreamPartTypeToolInputDelta,
							ID:    toolCallID,
							Delta: string(args),
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeToolInputEnd,
							ID:   toolCallID,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:             ai.StreamPartTypeToolCall,
							ID:               toolCallID,
							ToolCallName:     part.FunctionCall.Name,
							ToolCallInput:    string(args),
							ProviderExecuted: false,
						}) {
							return
						}

						toolCalls = append(toolCalls, ai.ToolCallContent{
							ToolCallID:       toolCallID,
							ToolName:         part.FunctionCall.Name,
							Input:            string(args),
							ProviderExecuted: false,
						})
					}
				}
			}

			if resp.UsageMetadata != nil {
				usage = mapUsage(resp.UsageMetadata)
			}

			if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
				lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
			}
		}

		// Close any open blocks before finishing.
		if isActiveText {
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID:   currentTextBlockID,
			}) {
				return
			}
		}
		if isActiveReasoning {
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeReasoningEnd,
				ID:   currentReasoningBlockID,
			}) {
				return
			}
		}

		finishReason := lastFinishReason
		if len(toolCalls) > 0 {
			finishReason = ai.FinishReasonToolCalls
		} else if finishReason == "" {
			finishReason = ai.FinishReasonStop
		}

		yield(ai.StreamPart{
			Type:         ai.StreamPartTypeFinish,
			Usage:        usage,
			FinishReason: finishReason,
		})
	}, nil
}

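// toGoogleTools converts function tools and the tool choice into genai
// function declarations plus a matching tool config; unsupported tool types
// are reported as warnings.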
func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}

			required := []string{}
			var properties map[string]any
			if props, ok := ft.InputSchema["properties"]; ok {
				properties, _ = props.(map[string]any)
			}
			if req, ok := ft.InputSchema["required"]; ok {
				switch reqVal := req.(type) {
				case []string:
					required = reqVal
				case []any:
					// Schemas decoded from JSON arrive as []any, so coerce
					// the string entries.
					for _, r := range reqVal {
						if s, ok := r.(string); ok {
							required = append(required, s)
						}
					}
				}
			}
			declaration := &genai.FunctionDeclaration{
				Name:        ft.Name,
				Description: ft.Description,
				Parameters: &genai.Schema{
					Type:       genai.TypeObject,
					Properties: convertSchemaProperties(properties),
					Required:   required,
				},
			}
			googleTools = append(googleTools, declaration)
			continue
		}
		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return //nolint: nakedret
	}
	switch *toolChoice {
	case ai.ToolChoiceAuto:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAuto,
			},
		}
	case ai.ToolChoiceRequired:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
			},
		}
	case ai.ToolChoiceNone:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeNone,
			},
		}
	default:
		// Any other value names a specific tool the model must call.
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
				AllowedFunctionNames: []string{
					string(*toolChoice),
				},
			},
		}
	}
	return //nolint: nakedret
}

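// convertSchemaProperties converts a JSON Schema "properties" map into the
// equivalent genai schemas.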
func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

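// convertToSchema converts a single JSON Schema definition into a
// genai.Schema, defaulting to a string schema when the shape is unknown.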
func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGoogle(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

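// processArrayItems converts the "items" definition of an array schema.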
func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

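// mapJSONTypeToGoogle maps a JSON Schema type name to the corresponding
// genai type.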
func mapJSONTypeToGoogle(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types.
	}
}

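// mapResponse converts a non-streaming genai response into an ai.Response.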
func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
	if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
		return nil, errors.New("no response from model")
	}

	var (
		content      []ai.Content
		finishReason ai.FinishReason
		hasToolCalls bool
		candidate    = response.Candidates[0]
	)

	for _, part := range candidate.Content.Parts {
		switch {
		case part.Text != "":
			if part.Thought {
				content = append(content, ai.ReasoningContent{Text: part.Text})
			} else {
				content = append(content, ai.TextContent{Text: part.Text})
			}
		case part.FunctionCall != nil:
			input, err := json.Marshal(part.FunctionCall.Args)
			if err != nil {
				return nil, err
			}
			toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
			content = append(content, ai.ToolCallContent{
				ToolCallID:       toolCallID,
				ToolName:         part.FunctionCall.Name,
				Input:            string(input),
				ProviderExecuted: false,
			})
			hasToolCalls = true
		default:
			// Silently skip unknown part types instead of erroring: this
			// allows forward compatibility with new part types.
		}
	}

	if hasToolCalls {
		finishReason = ai.FinishReasonToolCalls
	} else {
		finishReason = mapFinishReason(candidate.FinishReason)
	}

	return &ai.Response{
		Content:      content,
		Usage:        mapUsage(response.UsageMetadata),
		FinishReason: finishReason,
		Warnings:     warnings,
	}, nil
}

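// mapFinishReason translates genai finish reasons into their ai equivalents.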
func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return ai.FinishReasonStop
	case genai.FinishReasonMaxTokens:
		return ai.FinishReasonLength
	case genai.FinishReasonSafety,
		genai.FinishReasonBlocklist,
		genai.FinishReasonProhibitedContent,
		genai.FinishReasonSPII,
		genai.FinishReasonImageSafety:
		return ai.FinishReasonContentFilter
	case genai.FinishReasonRecitation,
		genai.FinishReasonLanguage,
		genai.FinishReasonMalformedFunctionCall:
		return ai.FinishReasonError
	case genai.FinishReasonOther:
		return ai.FinishReasonOther
	default:
		return ai.FinishReasonUnknown
	}
}

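// mapUsage converts genai usage metadata into ai.Usage token counts.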
func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
	if usage == nil {
		return ai.Usage{}
	}
	return ai.Usage{
		InputTokens:         int64(usage.PromptTokenCount),
		OutputTokens:        int64(usage.CandidatesTokenCount),
		TotalTokens:         int64(usage.TotalTokenCount),
		ReasoningTokens:     int64(usage.ThoughtsTokenCount),
		CacheCreationTokens: 0,
		// Cached content tokens are tokens read from an existing cache.
		CacheReadTokens: int64(usage.CachedContentTokenCount),
	}
}