1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "maps"
11 "net/http"
12 "strings"
13
14 "github.com/charmbracelet/fantasy/ai"
15 "github.com/charmbracelet/x/exp/slice"
16 "github.com/google/uuid"
17 "google.golang.org/genai"
18)
19
// Name identifies the Google provider. It is the default provider name used
// by New and the key under which *ProviderOptions are looked up in
// ai.Call.ProviderOptions.
const Name = "google"

type (
	// provider is the ai.Provider implementation returned by New.
	provider struct {
		options options
	}

	// options collects the configuration applied by Option functions.
	options struct {
		apiKey  string
		name    string
		headers map[string]string
		client  *http.Client
	}

	// Option mutates the provider configuration; see the With* helpers.
	Option = func(*options)
)
34
35func New(opts ...Option) ai.Provider {
36 options := options{
37 headers: map[string]string{},
38 }
39 for _, o := range opts {
40 o(&options)
41 }
42
43 options.name = cmp.Or(options.name, Name)
44
45 return &provider{
46 options: options,
47 }
48}
49
50func WithAPIKey(apiKey string) Option {
51 return func(o *options) {
52 o.apiKey = apiKey
53 }
54}
55
56func WithName(name string) Option {
57 return func(o *options) {
58 o.name = name
59 }
60}
61
62func WithHeaders(headers map[string]string) Option {
63 return func(o *options) {
64 maps.Copy(o.headers, headers)
65 }
66}
67
68func WithHTTPClient(client *http.Client) Option {
69 return func(o *options) {
70 o.client = client
71 }
72}
73
74func (*provider) Name() string {
75 return Name
76}
77
78func (a *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
79 var options ProviderOptions
80 if err := ai.ParseOptions(data, &options); err != nil {
81 return nil, err
82 }
83 return &options, nil
84}
85
86type languageModel struct {
87 provider string
88 modelID string
89 client *genai.Client
90 providerOptions options
91}
92
93// LanguageModel implements ai.Provider.
94func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
95 cc := &genai.ClientConfig{
96 APIKey: g.options.apiKey,
97 Backend: genai.BackendGeminiAPI,
98 HTTPClient: g.options.client,
99 }
100 client, err := genai.NewClient(context.Background(), cc)
101 if err != nil {
102 return nil, err
103 }
104 return &languageModel{
105 modelID: modelID,
106 provider: g.options.name,
107 providerOptions: g.options,
108 client: client,
109 }, nil
110}
111
112func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
113 config := &genai.GenerateContentConfig{}
114
115 providerOptions := &ProviderOptions{}
116 if v, ok := call.ProviderOptions[Name]; ok {
117 providerOptions, ok = v.(*ProviderOptions)
118 if !ok {
119 return nil, nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil)
120 }
121 }
122
123 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
124
125 if providerOptions.ThinkingConfig != nil {
126 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
127 *providerOptions.ThinkingConfig.IncludeThoughts &&
128 strings.HasPrefix(a.provider, "google.vertex.") {
129 warnings = append(warnings, ai.CallWarning{
130 Type: ai.CallWarningTypeOther,
131 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
132 "and might not be supported or could behave unexpectedly with the current Google provider " +
133 fmt.Sprintf("(%s)", a.provider),
134 })
135 }
136
137 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
138 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
139 warnings = append(warnings, ai.CallWarning{
140 Type: ai.CallWarningTypeOther,
141 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
142 })
143 providerOptions.ThinkingConfig.ThinkingBudget = ai.IntOption(128)
144 }
145 }
146
147 isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")
148
149 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
150 if len(content) > 0 && content[0].Role == genai.RoleUser {
151 systemParts := []string{}
152 for _, sp := range systemInstructions.Parts {
153 systemParts = append(systemParts, sp.Text)
154 }
155 systemMsg := strings.Join(systemParts, "\n")
156 content[0].Parts = append([]*genai.Part{
157 {
158 Text: systemMsg + "\n\n",
159 },
160 }, content[0].Parts...)
161 systemInstructions = nil
162 }
163 }
164
165 config.SystemInstruction = systemInstructions
166
167 if call.MaxOutputTokens != nil {
168 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
169 }
170
171 if call.Temperature != nil {
172 tmp := float32(*call.Temperature)
173 config.Temperature = &tmp
174 }
175 if call.TopK != nil {
176 tmp := float32(*call.TopK)
177 config.TopK = &tmp
178 }
179 if call.TopP != nil {
180 tmp := float32(*call.TopP)
181 config.TopP = &tmp
182 }
183 if call.FrequencyPenalty != nil {
184 tmp := float32(*call.FrequencyPenalty)
185 config.FrequencyPenalty = &tmp
186 }
187 if call.PresencePenalty != nil {
188 tmp := float32(*call.PresencePenalty)
189 config.PresencePenalty = &tmp
190 }
191
192 if providerOptions.ThinkingConfig != nil {
193 config.ThinkingConfig = &genai.ThinkingConfig{}
194 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
195 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
196 }
197 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
198 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
199 config.ThinkingConfig.ThinkingBudget = &tmp
200 }
201 }
202 for _, safetySetting := range providerOptions.SafetySettings {
203 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
204 Category: genai.HarmCategory(safetySetting.Category),
205 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
206 })
207 }
208 if providerOptions.CachedContent != "" {
209 config.CachedContent = providerOptions.CachedContent
210 }
211
212 if len(call.Tools) > 0 {
213 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
214 config.ToolConfig = toolChoice
215 config.Tools = append(config.Tools, &genai.Tool{
216 FunctionDeclarations: tools,
217 })
218 warnings = append(warnings, toolWarnings...)
219 }
220
221 return config, content, warnings, nil
222}
223
224func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) { //nolint: unparam
225 var systemInstructions *genai.Content
226 var content []*genai.Content
227 var warnings []ai.CallWarning
228
229 finishedSystemBlock := false
230 for _, msg := range prompt {
231 switch msg.Role {
232 case ai.MessageRoleSystem:
233 if finishedSystemBlock {
234 // skip multiple system messages that are separated by user/assistant messages
235 // TODO: see if we need to send error here?
236 continue
237 }
238 finishedSystemBlock = true
239
240 var systemMessages []string
241 for _, part := range msg.Content {
242 text, ok := ai.AsMessagePart[ai.TextPart](part)
243 if !ok || text.Text == "" {
244 continue
245 }
246 systemMessages = append(systemMessages, text.Text)
247 }
248 if len(systemMessages) > 0 {
249 systemInstructions = &genai.Content{
250 Parts: []*genai.Part{
251 {
252 Text: strings.Join(systemMessages, "\n"),
253 },
254 },
255 }
256 }
257 case ai.MessageRoleUser:
258 var parts []*genai.Part
259 for _, part := range msg.Content {
260 switch part.GetType() {
261 case ai.ContentTypeText:
262 text, ok := ai.AsMessagePart[ai.TextPart](part)
263 if !ok || text.Text == "" {
264 continue
265 }
266 parts = append(parts, &genai.Part{
267 Text: text.Text,
268 })
269 case ai.ContentTypeFile:
270 file, ok := ai.AsMessagePart[ai.FilePart](part)
271 if !ok {
272 continue
273 }
274 var encoded []byte
275 base64.StdEncoding.Encode(encoded, file.Data)
276 parts = append(parts, &genai.Part{
277 InlineData: &genai.Blob{
278 Data: encoded,
279 MIMEType: file.MediaType,
280 },
281 })
282 }
283 }
284 if len(parts) > 0 {
285 content = append(content, &genai.Content{
286 Role: genai.RoleUser,
287 Parts: parts,
288 })
289 }
290 case ai.MessageRoleAssistant:
291 var parts []*genai.Part
292 for _, part := range msg.Content {
293 switch part.GetType() {
294 case ai.ContentTypeText:
295 text, ok := ai.AsMessagePart[ai.TextPart](part)
296 if !ok || text.Text == "" {
297 continue
298 }
299 parts = append(parts, &genai.Part{
300 Text: text.Text,
301 })
302 case ai.ContentTypeToolCall:
303 toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
304 if !ok {
305 continue
306 }
307
308 var result map[string]any
309 err := json.Unmarshal([]byte(toolCall.Input), &result)
310 if err != nil {
311 continue
312 }
313 parts = append(parts, &genai.Part{
314 FunctionCall: &genai.FunctionCall{
315 ID: toolCall.ToolCallID,
316 Name: toolCall.ToolName,
317 Args: result,
318 },
319 })
320 }
321 }
322 if len(parts) > 0 {
323 content = append(content, &genai.Content{
324 Role: genai.RoleModel,
325 Parts: parts,
326 })
327 }
328 case ai.MessageRoleTool:
329 var parts []*genai.Part
330 for _, part := range msg.Content {
331 switch part.GetType() {
332 case ai.ContentTypeToolResult:
333 result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
334 if !ok {
335 continue
336 }
337 var toolCall ai.ToolCallPart
338 for _, m := range prompt {
339 if m.Role == ai.MessageRoleAssistant {
340 for _, content := range m.Content {
341 tc, ok := ai.AsMessagePart[ai.ToolCallPart](content)
342 if !ok {
343 continue
344 }
345 if tc.ToolCallID == result.ToolCallID {
346 toolCall = tc
347 break
348 }
349 }
350 }
351 }
352 switch result.Output.GetType() {
353 case ai.ToolResultContentTypeText:
354 content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
355 if !ok {
356 continue
357 }
358 response := map[string]any{"result": content.Text}
359 parts = append(parts, &genai.Part{
360 FunctionResponse: &genai.FunctionResponse{
361 ID: result.ToolCallID,
362 Response: response,
363 Name: toolCall.ToolName,
364 },
365 })
366
367 case ai.ToolResultContentTypeError:
368 content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
369 if !ok {
370 continue
371 }
372 response := map[string]any{"result": content.Error.Error()}
373 parts = append(parts, &genai.Part{
374 FunctionResponse: &genai.FunctionResponse{
375 ID: result.ToolCallID,
376 Response: response,
377 Name: toolCall.ToolName,
378 },
379 })
380 }
381 }
382 }
383 if len(parts) > 0 {
384 content = append(content, &genai.Content{
385 Role: genai.RoleUser,
386 Parts: parts,
387 })
388 }
389 }
390 }
391 return systemInstructions, content, warnings
392}
393
394// Generate implements ai.LanguageModel.
395func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
396 config, contents, warnings, err := g.prepareParams(call)
397 if err != nil {
398 return nil, err
399 }
400
401 lastMessage, history, ok := slice.Pop(contents)
402 if !ok {
403 return nil, errors.New("no messages to send")
404 }
405
406 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
407 if err != nil {
408 return nil, err
409 }
410
411 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
412 if err != nil {
413 return nil, err
414 }
415
416 return mapResponse(response, warnings)
417}
418
419// Model implements ai.LanguageModel.
420func (g *languageModel) Model() string {
421 return g.modelID
422}
423
424// Provider implements ai.LanguageModel.
425func (g *languageModel) Provider() string {
426 return g.provider
427}
428
429// Stream implements ai.LanguageModel.
430func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
431 config, contents, warnings, err := g.prepareParams(call)
432 if err != nil {
433 return nil, err
434 }
435
436 lastMessage, history, ok := slice.Pop(contents)
437 if !ok {
438 return nil, errors.New("no messages to send")
439 }
440
441 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
442 if err != nil {
443 return nil, err
444 }
445
446 return func(yield func(ai.StreamPart) bool) {
447 if len(warnings) > 0 {
448 if !yield(ai.StreamPart{
449 Type: ai.StreamPartTypeWarnings,
450 Warnings: warnings,
451 }) {
452 return
453 }
454 }
455
456 var currentContent string
457 var toolCalls []ai.ToolCallContent
458 var isActiveText bool
459 var isActiveReasoning bool
460 var blockCounter int
461 var currentTextBlockID string
462 var currentReasoningBlockID string
463 var usage ai.Usage
464 var lastFinishReason ai.FinishReason
465
466 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
467 if err != nil {
468 yield(ai.StreamPart{
469 Type: ai.StreamPartTypeError,
470 Error: err,
471 })
472 return
473 }
474
475 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
476 for _, part := range resp.Candidates[0].Content.Parts {
477 switch {
478 case part.Text != "":
479 delta := part.Text
480 if delta != "" {
481 // Check if this is a reasoning/thought part
482 if part.Thought {
483 // End any active text block before starting reasoning
484 if isActiveText {
485 isActiveText = false
486 if !yield(ai.StreamPart{
487 Type: ai.StreamPartTypeTextEnd,
488 ID: currentTextBlockID,
489 }) {
490 return
491 }
492 }
493
494 // Start new reasoning block if not already active
495 if !isActiveReasoning {
496 isActiveReasoning = true
497 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
498 blockCounter++
499 if !yield(ai.StreamPart{
500 Type: ai.StreamPartTypeReasoningStart,
501 ID: currentReasoningBlockID,
502 }) {
503 return
504 }
505 }
506
507 if !yield(ai.StreamPart{
508 Type: ai.StreamPartTypeReasoningDelta,
509 ID: currentReasoningBlockID,
510 Delta: delta,
511 }) {
512 return
513 }
514 } else {
515 // Regular text part
516 // End any active reasoning block before starting text
517 if isActiveReasoning {
518 isActiveReasoning = false
519 if !yield(ai.StreamPart{
520 Type: ai.StreamPartTypeReasoningEnd,
521 ID: currentReasoningBlockID,
522 }) {
523 return
524 }
525 }
526
527 // Start new text block if not already active
528 if !isActiveText {
529 isActiveText = true
530 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
531 blockCounter++
532 if !yield(ai.StreamPart{
533 Type: ai.StreamPartTypeTextStart,
534 ID: currentTextBlockID,
535 }) {
536 return
537 }
538 }
539
540 if !yield(ai.StreamPart{
541 Type: ai.StreamPartTypeTextDelta,
542 ID: currentTextBlockID,
543 Delta: delta,
544 }) {
545 return
546 }
547 currentContent += delta
548 }
549 }
550 case part.FunctionCall != nil:
551 // End any active text or reasoning blocks
552 if isActiveText {
553 isActiveText = false
554 if !yield(ai.StreamPart{
555 Type: ai.StreamPartTypeTextEnd,
556 ID: currentTextBlockID,
557 }) {
558 return
559 }
560 }
561 if isActiveReasoning {
562 isActiveReasoning = false
563 if !yield(ai.StreamPart{
564 Type: ai.StreamPartTypeReasoningEnd,
565 ID: currentReasoningBlockID,
566 }) {
567 return
568 }
569 }
570
571 toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
572
573 args, err := json.Marshal(part.FunctionCall.Args)
574 if err != nil {
575 yield(ai.StreamPart{
576 Type: ai.StreamPartTypeError,
577 Error: err,
578 })
579 return
580 }
581
582 if !yield(ai.StreamPart{
583 Type: ai.StreamPartTypeToolInputStart,
584 ID: toolCallID,
585 ToolCallName: part.FunctionCall.Name,
586 }) {
587 return
588 }
589
590 if !yield(ai.StreamPart{
591 Type: ai.StreamPartTypeToolInputDelta,
592 ID: toolCallID,
593 Delta: string(args),
594 }) {
595 return
596 }
597
598 if !yield(ai.StreamPart{
599 Type: ai.StreamPartTypeToolInputEnd,
600 ID: toolCallID,
601 }) {
602 return
603 }
604
605 if !yield(ai.StreamPart{
606 Type: ai.StreamPartTypeToolCall,
607 ID: toolCallID,
608 ToolCallName: part.FunctionCall.Name,
609 ToolCallInput: string(args),
610 ProviderExecuted: false,
611 }) {
612 return
613 }
614
615 toolCalls = append(toolCalls, ai.ToolCallContent{
616 ToolCallID: toolCallID,
617 ToolName: part.FunctionCall.Name,
618 Input: string(args),
619 ProviderExecuted: false,
620 })
621 }
622 }
623 }
624
625 if resp.UsageMetadata != nil {
626 usage = mapUsage(resp.UsageMetadata)
627 }
628
629 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
630 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
631 }
632 }
633
634 // Close any open blocks before finishing
635 if isActiveText {
636 if !yield(ai.StreamPart{
637 Type: ai.StreamPartTypeTextEnd,
638 ID: currentTextBlockID,
639 }) {
640 return
641 }
642 }
643 if isActiveReasoning {
644 if !yield(ai.StreamPart{
645 Type: ai.StreamPartTypeReasoningEnd,
646 ID: currentReasoningBlockID,
647 }) {
648 return
649 }
650 }
651
652 finishReason := lastFinishReason
653 if len(toolCalls) > 0 {
654 finishReason = ai.FinishReasonToolCalls
655 } else if finishReason == "" {
656 finishReason = ai.FinishReasonStop
657 }
658
659 yield(ai.StreamPart{
660 Type: ai.StreamPartTypeFinish,
661 Usage: usage,
662 FinishReason: finishReason,
663 })
664 }, nil
665}
666
667func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
668 for _, tool := range tools {
669 if tool.GetType() == ai.ToolTypeFunction {
670 ft, ok := tool.(ai.FunctionTool)
671 if !ok {
672 continue
673 }
674
675 required := []string{}
676 var properties map[string]any
677 if props, ok := ft.InputSchema["properties"]; ok {
678 properties, _ = props.(map[string]any)
679 }
680 if req, ok := ft.InputSchema["required"]; ok {
681 if reqArr, ok := req.([]string); ok {
682 required = reqArr
683 }
684 }
685 declaration := &genai.FunctionDeclaration{
686 Name: ft.Name,
687 Description: ft.Description,
688 Parameters: &genai.Schema{
689 Type: genai.TypeObject,
690 Properties: convertSchemaProperties(properties),
691 Required: required,
692 },
693 }
694 googleTools = append(googleTools, declaration)
695 continue
696 }
697 // TODO: handle provider tool calls
698 warnings = append(warnings, ai.CallWarning{
699 Type: ai.CallWarningTypeUnsupportedTool,
700 Tool: tool,
701 Message: "tool is not supported",
702 })
703 }
704 if toolChoice == nil {
705 return googleTools, googleToolChoice, warnings
706 }
707 switch *toolChoice {
708 case ai.ToolChoiceAuto:
709 googleToolChoice = &genai.ToolConfig{
710 FunctionCallingConfig: &genai.FunctionCallingConfig{
711 Mode: genai.FunctionCallingConfigModeAuto,
712 },
713 }
714 case ai.ToolChoiceRequired:
715 googleToolChoice = &genai.ToolConfig{
716 FunctionCallingConfig: &genai.FunctionCallingConfig{
717 Mode: genai.FunctionCallingConfigModeAny,
718 },
719 }
720 case ai.ToolChoiceNone:
721 googleToolChoice = &genai.ToolConfig{
722 FunctionCallingConfig: &genai.FunctionCallingConfig{
723 Mode: genai.FunctionCallingConfigModeNone,
724 },
725 }
726 default:
727 googleToolChoice = &genai.ToolConfig{
728 FunctionCallingConfig: &genai.FunctionCallingConfig{
729 Mode: genai.FunctionCallingConfigModeAny,
730 AllowedFunctionNames: []string{
731 string(*toolChoice),
732 },
733 },
734 }
735 }
736 return googleTools, googleToolChoice, warnings
737}
738
739func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
740 properties := make(map[string]*genai.Schema)
741
742 for name, param := range parameters {
743 properties[name] = convertToSchema(param)
744 }
745
746 return properties
747}
748
749func convertToSchema(param any) *genai.Schema {
750 schema := &genai.Schema{Type: genai.TypeString}
751
752 paramMap, ok := param.(map[string]any)
753 if !ok {
754 return schema
755 }
756
757 if desc, ok := paramMap["description"].(string); ok {
758 schema.Description = desc
759 }
760
761 typeVal, hasType := paramMap["type"]
762 if !hasType {
763 return schema
764 }
765
766 typeStr, ok := typeVal.(string)
767 if !ok {
768 return schema
769 }
770
771 schema.Type = mapJSONTypeToGoogle(typeStr)
772
773 switch typeStr {
774 case "array":
775 schema.Items = processArrayItems(paramMap)
776 case "object":
777 if props, ok := paramMap["properties"].(map[string]any); ok {
778 schema.Properties = convertSchemaProperties(props)
779 }
780 }
781
782 return schema
783}
784
785func processArrayItems(paramMap map[string]any) *genai.Schema {
786 items, ok := paramMap["items"].(map[string]any)
787 if !ok {
788 return nil
789 }
790
791 return convertToSchema(items)
792}
793
794func mapJSONTypeToGoogle(jsonType string) genai.Type {
795 switch jsonType {
796 case "string":
797 return genai.TypeString
798 case "number":
799 return genai.TypeNumber
800 case "integer":
801 return genai.TypeInteger
802 case "boolean":
803 return genai.TypeBoolean
804 case "array":
805 return genai.TypeArray
806 case "object":
807 return genai.TypeObject
808 default:
809 return genai.TypeString // Default to string for unknown types
810 }
811}
812
813func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
814 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
815 return nil, errors.New("no response from model")
816 }
817
818 var (
819 content []ai.Content
820 finishReason ai.FinishReason
821 hasToolCalls bool
822 candidate = response.Candidates[0]
823 )
824
825 for _, part := range candidate.Content.Parts {
826 switch {
827 case part.Text != "":
828 if part.Thought {
829 content = append(content, ai.ReasoningContent{Text: part.Text})
830 } else {
831 content = append(content, ai.TextContent{Text: part.Text})
832 }
833 case part.FunctionCall != nil:
834 input, err := json.Marshal(part.FunctionCall.Args)
835 if err != nil {
836 return nil, err
837 }
838 toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
839 content = append(content, ai.ToolCallContent{
840 ToolCallID: toolCallID,
841 ToolName: part.FunctionCall.Name,
842 Input: string(input),
843 ProviderExecuted: false,
844 })
845 hasToolCalls = true
846 default:
847 // Silently skip unknown part types instead of erroring
848 // This allows for forward compatibility with new part types
849 }
850 }
851
852 if hasToolCalls {
853 finishReason = ai.FinishReasonToolCalls
854 } else {
855 finishReason = mapFinishReason(candidate.FinishReason)
856 }
857
858 return &ai.Response{
859 Content: content,
860 Usage: mapUsage(response.UsageMetadata),
861 FinishReason: finishReason,
862 Warnings: warnings,
863 }, nil
864}
865
866func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
867 switch reason {
868 case genai.FinishReasonStop:
869 return ai.FinishReasonStop
870 case genai.FinishReasonMaxTokens:
871 return ai.FinishReasonLength
872 case genai.FinishReasonSafety,
873 genai.FinishReasonBlocklist,
874 genai.FinishReasonProhibitedContent,
875 genai.FinishReasonSPII,
876 genai.FinishReasonImageSafety:
877 return ai.FinishReasonContentFilter
878 case genai.FinishReasonRecitation,
879 genai.FinishReasonLanguage,
880 genai.FinishReasonMalformedFunctionCall:
881 return ai.FinishReasonError
882 case genai.FinishReasonOther:
883 return ai.FinishReasonOther
884 default:
885 return ai.FinishReasonUnknown
886 }
887}
888
889func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
890 return ai.Usage{
891 InputTokens: int64(usage.ToolUsePromptTokenCount),
892 OutputTokens: int64(usage.CandidatesTokenCount),
893 TotalTokens: int64(usage.TotalTokenCount),
894 ReasoningTokens: int64(usage.ThoughtsTokenCount),
895 CacheCreationTokens: int64(usage.CachedContentTokenCount),
896 CacheReadTokens: 0,
897 }
898}