1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "maps"
11 "net/http"
12 "strings"
13
14 "cloud.google.com/go/auth"
15 "github.com/charmbracelet/fantasy/ai"
16 "github.com/charmbracelet/fantasy/anthropic"
17 "github.com/charmbracelet/x/exp/slice"
18 "github.com/google/uuid"
19 "google.golang.org/genai"
20)
21
// Name is the default provider name. It is also the key under which
// Google-specific options are looked up in a call's provider options.
const Name = "google"
23
// provider is the value returned by New. It only carries the configuration
// assembled from the functional options.
type provider struct {
	options options
}
27
// options holds the provider configuration built up by the Option functions.
type options struct {
	// apiKey is the Gemini API key; WithVertex clears it.
	apiKey string
	// name is the provider name reported by language models; New defaults it to Name.
	name string
	// headers accumulates the entries supplied through WithHeaders.
	headers map[string]string
	// client is the HTTP client handed to the underlying API clients.
	client *http.Client
	// backend selects between the Gemini API and Vertex AI.
	backend genai.Backend
	// project is the Vertex project; WithGeminiAPIKey clears it.
	project string
	// location is the Vertex location; WithGeminiAPIKey clears it.
	location string
	// skipAuth makes LanguageModel install placeholder credentials.
	skipAuth bool
}
38
// Option mutates the provider configuration; options are passed to New.
type Option = func(*options)
40
41func New(opts ...Option) ai.Provider {
42 options := options{
43 headers: map[string]string{},
44 }
45 for _, o := range opts {
46 o(&options)
47 }
48
49 options.name = cmp.Or(options.name, Name)
50
51 return &provider{
52 options: options,
53 }
54}
55
56func WithGeminiAPIKey(apiKey string) Option {
57 return func(o *options) {
58 o.backend = genai.BackendGeminiAPI
59 o.apiKey = apiKey
60 o.project = ""
61 o.location = ""
62 }
63}
64
65func WithVertex(project, location string) Option {
66 if project == "" || location == "" {
67 panic("project and location must be provided")
68 }
69 return func(o *options) {
70 o.backend = genai.BackendVertexAI
71 o.apiKey = ""
72 o.project = project
73 o.location = location
74 }
75}
76
77func WithSkipAuth(skipAuth bool) Option {
78 return func(o *options) {
79 o.skipAuth = skipAuth
80 }
81}
82
83func WithName(name string) Option {
84 return func(o *options) {
85 o.name = name
86 }
87}
88
89func WithHeaders(headers map[string]string) Option {
90 return func(o *options) {
91 maps.Copy(o.headers, headers)
92 }
93}
94
95func WithHTTPClient(client *http.Client) Option {
96 return func(o *options) {
97 o.client = client
98 }
99}
100
// Name returns the package-level Name constant. Note that it does not
// reflect a custom name configured through WithName.
func (*provider) Name() string {
	return Name
}
104
105func (a *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
106 var options ProviderOptions
107 if err := ai.ParseOptions(data, &options); err != nil {
108 return nil, err
109 }
110 return &options, nil
111}
112
// languageModel is the ai.LanguageModel implementation backed by a genai client.
type languageModel struct {
	// provider is the name returned by Provider.
	provider string
	// modelID is the model identifier passed to the genai chat API.
	modelID string
	// client is the genai client used to create chats.
	client *genai.Client
	// providerOptions is a copy of the provider configuration the model was created from.
	providerOptions options
}
119
120// LanguageModel implements ai.Provider.
121func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
122 if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
123 return anthropic.New(
124 anthropic.WithVertex(g.options.project, g.options.location),
125 anthropic.WithHTTPClient(g.options.client),
126 anthropic.WithSkipGoogleAuth(g.options.skipAuth),
127 ).LanguageModel(modelID)
128 }
129
130 cc := &genai.ClientConfig{
131 HTTPClient: g.options.client,
132 Backend: g.options.backend,
133 APIKey: g.options.apiKey,
134 Project: g.options.project,
135 Location: g.options.location,
136 }
137 if g.options.skipAuth {
138 cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
139 }
140 client, err := genai.NewClient(context.Background(), cc)
141 if err != nil {
142 return nil, err
143 }
144 return &languageModel{
145 modelID: modelID,
146 provider: g.options.name,
147 providerOptions: g.options,
148 client: client,
149 }, nil
150}
151
152func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
153 config := &genai.GenerateContentConfig{}
154
155 providerOptions := &ProviderOptions{}
156 if v, ok := call.ProviderOptions[Name]; ok {
157 providerOptions, ok = v.(*ProviderOptions)
158 if !ok {
159 return nil, nil, nil, ai.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
160 }
161 }
162
163 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
164
165 if providerOptions.ThinkingConfig != nil {
166 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
167 *providerOptions.ThinkingConfig.IncludeThoughts &&
168 strings.HasPrefix(a.provider, "google.vertex.") {
169 warnings = append(warnings, ai.CallWarning{
170 Type: ai.CallWarningTypeOther,
171 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
172 "and might not be supported or could behave unexpectedly with the current Google provider " +
173 fmt.Sprintf("(%s)", a.provider),
174 })
175 }
176
177 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
178 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
179 warnings = append(warnings, ai.CallWarning{
180 Type: ai.CallWarningTypeOther,
181 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
182 })
183 providerOptions.ThinkingConfig.ThinkingBudget = ai.IntOption(128)
184 }
185 }
186
187 isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")
188
189 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
190 if len(content) > 0 && content[0].Role == genai.RoleUser {
191 systemParts := []string{}
192 for _, sp := range systemInstructions.Parts {
193 systemParts = append(systemParts, sp.Text)
194 }
195 systemMsg := strings.Join(systemParts, "\n")
196 content[0].Parts = append([]*genai.Part{
197 {
198 Text: systemMsg + "\n\n",
199 },
200 }, content[0].Parts...)
201 systemInstructions = nil
202 }
203 }
204
205 config.SystemInstruction = systemInstructions
206
207 if call.MaxOutputTokens != nil {
208 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
209 }
210
211 if call.Temperature != nil {
212 tmp := float32(*call.Temperature)
213 config.Temperature = &tmp
214 }
215 if call.TopK != nil {
216 tmp := float32(*call.TopK)
217 config.TopK = &tmp
218 }
219 if call.TopP != nil {
220 tmp := float32(*call.TopP)
221 config.TopP = &tmp
222 }
223 if call.FrequencyPenalty != nil {
224 tmp := float32(*call.FrequencyPenalty)
225 config.FrequencyPenalty = &tmp
226 }
227 if call.PresencePenalty != nil {
228 tmp := float32(*call.PresencePenalty)
229 config.PresencePenalty = &tmp
230 }
231
232 if providerOptions.ThinkingConfig != nil {
233 config.ThinkingConfig = &genai.ThinkingConfig{}
234 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
235 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
236 }
237 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
238 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
239 config.ThinkingConfig.ThinkingBudget = &tmp
240 }
241 }
242 for _, safetySetting := range providerOptions.SafetySettings {
243 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
244 Category: genai.HarmCategory(safetySetting.Category),
245 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
246 })
247 }
248 if providerOptions.CachedContent != "" {
249 config.CachedContent = providerOptions.CachedContent
250 }
251
252 if len(call.Tools) > 0 {
253 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
254 config.ToolConfig = toolChoice
255 config.Tools = append(config.Tools, &genai.Tool{
256 FunctionDeclarations: tools,
257 })
258 warnings = append(warnings, toolWarnings...)
259 }
260
261 return config, content, warnings, nil
262}
263
264func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) { //nolint: unparam
265 var systemInstructions *genai.Content
266 var content []*genai.Content
267 var warnings []ai.CallWarning
268
269 finishedSystemBlock := false
270 for _, msg := range prompt {
271 switch msg.Role {
272 case ai.MessageRoleSystem:
273 if finishedSystemBlock {
274 // skip multiple system messages that are separated by user/assistant messages
275 // TODO: see if we need to send error here?
276 continue
277 }
278 finishedSystemBlock = true
279
280 var systemMessages []string
281 for _, part := range msg.Content {
282 text, ok := ai.AsMessagePart[ai.TextPart](part)
283 if !ok || text.Text == "" {
284 continue
285 }
286 systemMessages = append(systemMessages, text.Text)
287 }
288 if len(systemMessages) > 0 {
289 systemInstructions = &genai.Content{
290 Parts: []*genai.Part{
291 {
292 Text: strings.Join(systemMessages, "\n"),
293 },
294 },
295 }
296 }
297 case ai.MessageRoleUser:
298 var parts []*genai.Part
299 for _, part := range msg.Content {
300 switch part.GetType() {
301 case ai.ContentTypeText:
302 text, ok := ai.AsMessagePart[ai.TextPart](part)
303 if !ok || text.Text == "" {
304 continue
305 }
306 parts = append(parts, &genai.Part{
307 Text: text.Text,
308 })
309 case ai.ContentTypeFile:
310 file, ok := ai.AsMessagePart[ai.FilePart](part)
311 if !ok {
312 continue
313 }
314 var encoded []byte
315 base64.StdEncoding.Encode(encoded, file.Data)
316 parts = append(parts, &genai.Part{
317 InlineData: &genai.Blob{
318 Data: encoded,
319 MIMEType: file.MediaType,
320 },
321 })
322 }
323 }
324 if len(parts) > 0 {
325 content = append(content, &genai.Content{
326 Role: genai.RoleUser,
327 Parts: parts,
328 })
329 }
330 case ai.MessageRoleAssistant:
331 var parts []*genai.Part
332 for _, part := range msg.Content {
333 switch part.GetType() {
334 case ai.ContentTypeText:
335 text, ok := ai.AsMessagePart[ai.TextPart](part)
336 if !ok || text.Text == "" {
337 continue
338 }
339 parts = append(parts, &genai.Part{
340 Text: text.Text,
341 })
342 case ai.ContentTypeToolCall:
343 toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
344 if !ok {
345 continue
346 }
347
348 var result map[string]any
349 err := json.Unmarshal([]byte(toolCall.Input), &result)
350 if err != nil {
351 continue
352 }
353 parts = append(parts, &genai.Part{
354 FunctionCall: &genai.FunctionCall{
355 ID: toolCall.ToolCallID,
356 Name: toolCall.ToolName,
357 Args: result,
358 },
359 })
360 }
361 }
362 if len(parts) > 0 {
363 content = append(content, &genai.Content{
364 Role: genai.RoleModel,
365 Parts: parts,
366 })
367 }
368 case ai.MessageRoleTool:
369 var parts []*genai.Part
370 for _, part := range msg.Content {
371 switch part.GetType() {
372 case ai.ContentTypeToolResult:
373 result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
374 if !ok {
375 continue
376 }
377 var toolCall ai.ToolCallPart
378 for _, m := range prompt {
379 if m.Role == ai.MessageRoleAssistant {
380 for _, content := range m.Content {
381 tc, ok := ai.AsMessagePart[ai.ToolCallPart](content)
382 if !ok {
383 continue
384 }
385 if tc.ToolCallID == result.ToolCallID {
386 toolCall = tc
387 break
388 }
389 }
390 }
391 }
392 switch result.Output.GetType() {
393 case ai.ToolResultContentTypeText:
394 content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
395 if !ok {
396 continue
397 }
398 response := map[string]any{"result": content.Text}
399 parts = append(parts, &genai.Part{
400 FunctionResponse: &genai.FunctionResponse{
401 ID: result.ToolCallID,
402 Response: response,
403 Name: toolCall.ToolName,
404 },
405 })
406
407 case ai.ToolResultContentTypeError:
408 content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
409 if !ok {
410 continue
411 }
412 response := map[string]any{"result": content.Error.Error()}
413 parts = append(parts, &genai.Part{
414 FunctionResponse: &genai.FunctionResponse{
415 ID: result.ToolCallID,
416 Response: response,
417 Name: toolCall.ToolName,
418 },
419 })
420 }
421 }
422 }
423 if len(parts) > 0 {
424 content = append(content, &genai.Content{
425 Role: genai.RoleUser,
426 Parts: parts,
427 })
428 }
429 default:
430 panic("unsupported message role: " + msg.Role)
431 }
432 }
433 return systemInstructions, content, warnings
434}
435
436// Generate implements ai.LanguageModel.
437func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
438 config, contents, warnings, err := g.prepareParams(call)
439 if err != nil {
440 return nil, err
441 }
442
443 lastMessage, history, ok := slice.Pop(contents)
444 if !ok {
445 return nil, errors.New("no messages to send")
446 }
447
448 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
449 if err != nil {
450 return nil, err
451 }
452
453 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
454 if err != nil {
455 return nil, err
456 }
457
458 return mapResponse(response, warnings)
459}
460
// Model implements ai.LanguageModel. It returns the model ID the language
// model was created with.
func (g *languageModel) Model() string {
	return g.modelID
}
465
// Provider implements ai.LanguageModel. It returns the provider name the
// language model was created with.
func (g *languageModel) Provider() string {
	return g.provider
}
470
471// Stream implements ai.LanguageModel.
472func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
473 config, contents, warnings, err := g.prepareParams(call)
474 if err != nil {
475 return nil, err
476 }
477
478 lastMessage, history, ok := slice.Pop(contents)
479 if !ok {
480 return nil, errors.New("no messages to send")
481 }
482
483 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
484 if err != nil {
485 return nil, err
486 }
487
488 return func(yield func(ai.StreamPart) bool) {
489 if len(warnings) > 0 {
490 if !yield(ai.StreamPart{
491 Type: ai.StreamPartTypeWarnings,
492 Warnings: warnings,
493 }) {
494 return
495 }
496 }
497
498 var currentContent string
499 var toolCalls []ai.ToolCallContent
500 var isActiveText bool
501 var isActiveReasoning bool
502 var blockCounter int
503 var currentTextBlockID string
504 var currentReasoningBlockID string
505 var usage ai.Usage
506 var lastFinishReason ai.FinishReason
507
508 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
509 if err != nil {
510 yield(ai.StreamPart{
511 Type: ai.StreamPartTypeError,
512 Error: err,
513 })
514 return
515 }
516
517 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
518 for _, part := range resp.Candidates[0].Content.Parts {
519 switch {
520 case part.Text != "":
521 delta := part.Text
522 if delta != "" {
523 // Check if this is a reasoning/thought part
524 if part.Thought {
525 // End any active text block before starting reasoning
526 if isActiveText {
527 isActiveText = false
528 if !yield(ai.StreamPart{
529 Type: ai.StreamPartTypeTextEnd,
530 ID: currentTextBlockID,
531 }) {
532 return
533 }
534 }
535
536 // Start new reasoning block if not already active
537 if !isActiveReasoning {
538 isActiveReasoning = true
539 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
540 blockCounter++
541 if !yield(ai.StreamPart{
542 Type: ai.StreamPartTypeReasoningStart,
543 ID: currentReasoningBlockID,
544 }) {
545 return
546 }
547 }
548
549 if !yield(ai.StreamPart{
550 Type: ai.StreamPartTypeReasoningDelta,
551 ID: currentReasoningBlockID,
552 Delta: delta,
553 }) {
554 return
555 }
556 } else {
557 // Regular text part
558 // End any active reasoning block before starting text
559 if isActiveReasoning {
560 isActiveReasoning = false
561 if !yield(ai.StreamPart{
562 Type: ai.StreamPartTypeReasoningEnd,
563 ID: currentReasoningBlockID,
564 }) {
565 return
566 }
567 }
568
569 // Start new text block if not already active
570 if !isActiveText {
571 isActiveText = true
572 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
573 blockCounter++
574 if !yield(ai.StreamPart{
575 Type: ai.StreamPartTypeTextStart,
576 ID: currentTextBlockID,
577 }) {
578 return
579 }
580 }
581
582 if !yield(ai.StreamPart{
583 Type: ai.StreamPartTypeTextDelta,
584 ID: currentTextBlockID,
585 Delta: delta,
586 }) {
587 return
588 }
589 currentContent += delta
590 }
591 }
592 case part.FunctionCall != nil:
593 // End any active text or reasoning blocks
594 if isActiveText {
595 isActiveText = false
596 if !yield(ai.StreamPart{
597 Type: ai.StreamPartTypeTextEnd,
598 ID: currentTextBlockID,
599 }) {
600 return
601 }
602 }
603 if isActiveReasoning {
604 isActiveReasoning = false
605 if !yield(ai.StreamPart{
606 Type: ai.StreamPartTypeReasoningEnd,
607 ID: currentReasoningBlockID,
608 }) {
609 return
610 }
611 }
612
613 toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
614
615 args, err := json.Marshal(part.FunctionCall.Args)
616 if err != nil {
617 yield(ai.StreamPart{
618 Type: ai.StreamPartTypeError,
619 Error: err,
620 })
621 return
622 }
623
624 if !yield(ai.StreamPart{
625 Type: ai.StreamPartTypeToolInputStart,
626 ID: toolCallID,
627 ToolCallName: part.FunctionCall.Name,
628 }) {
629 return
630 }
631
632 if !yield(ai.StreamPart{
633 Type: ai.StreamPartTypeToolInputDelta,
634 ID: toolCallID,
635 Delta: string(args),
636 }) {
637 return
638 }
639
640 if !yield(ai.StreamPart{
641 Type: ai.StreamPartTypeToolInputEnd,
642 ID: toolCallID,
643 }) {
644 return
645 }
646
647 if !yield(ai.StreamPart{
648 Type: ai.StreamPartTypeToolCall,
649 ID: toolCallID,
650 ToolCallName: part.FunctionCall.Name,
651 ToolCallInput: string(args),
652 ProviderExecuted: false,
653 }) {
654 return
655 }
656
657 toolCalls = append(toolCalls, ai.ToolCallContent{
658 ToolCallID: toolCallID,
659 ToolName: part.FunctionCall.Name,
660 Input: string(args),
661 ProviderExecuted: false,
662 })
663 }
664 }
665 }
666
667 if resp.UsageMetadata != nil {
668 usage = mapUsage(resp.UsageMetadata)
669 }
670
671 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
672 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
673 }
674 }
675
676 // Close any open blocks before finishing
677 if isActiveText {
678 if !yield(ai.StreamPart{
679 Type: ai.StreamPartTypeTextEnd,
680 ID: currentTextBlockID,
681 }) {
682 return
683 }
684 }
685 if isActiveReasoning {
686 if !yield(ai.StreamPart{
687 Type: ai.StreamPartTypeReasoningEnd,
688 ID: currentReasoningBlockID,
689 }) {
690 return
691 }
692 }
693
694 finishReason := lastFinishReason
695 if len(toolCalls) > 0 {
696 finishReason = ai.FinishReasonToolCalls
697 } else if finishReason == "" {
698 finishReason = ai.FinishReasonStop
699 }
700
701 yield(ai.StreamPart{
702 Type: ai.StreamPartTypeFinish,
703 Usage: usage,
704 FinishReason: finishReason,
705 })
706 }, nil
707}
708
709func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
710 for _, tool := range tools {
711 if tool.GetType() == ai.ToolTypeFunction {
712 ft, ok := tool.(ai.FunctionTool)
713 if !ok {
714 continue
715 }
716
717 required := []string{}
718 var properties map[string]any
719 if props, ok := ft.InputSchema["properties"]; ok {
720 properties, _ = props.(map[string]any)
721 }
722 if req, ok := ft.InputSchema["required"]; ok {
723 if reqArr, ok := req.([]string); ok {
724 required = reqArr
725 }
726 }
727 declaration := &genai.FunctionDeclaration{
728 Name: ft.Name,
729 Description: ft.Description,
730 Parameters: &genai.Schema{
731 Type: genai.TypeObject,
732 Properties: convertSchemaProperties(properties),
733 Required: required,
734 },
735 }
736 googleTools = append(googleTools, declaration)
737 continue
738 }
739 // TODO: handle provider tool calls
740 warnings = append(warnings, ai.CallWarning{
741 Type: ai.CallWarningTypeUnsupportedTool,
742 Tool: tool,
743 Message: "tool is not supported",
744 })
745 }
746 if toolChoice == nil {
747 return googleTools, googleToolChoice, warnings
748 }
749 switch *toolChoice {
750 case ai.ToolChoiceAuto:
751 googleToolChoice = &genai.ToolConfig{
752 FunctionCallingConfig: &genai.FunctionCallingConfig{
753 Mode: genai.FunctionCallingConfigModeAuto,
754 },
755 }
756 case ai.ToolChoiceRequired:
757 googleToolChoice = &genai.ToolConfig{
758 FunctionCallingConfig: &genai.FunctionCallingConfig{
759 Mode: genai.FunctionCallingConfigModeAny,
760 },
761 }
762 case ai.ToolChoiceNone:
763 googleToolChoice = &genai.ToolConfig{
764 FunctionCallingConfig: &genai.FunctionCallingConfig{
765 Mode: genai.FunctionCallingConfigModeNone,
766 },
767 }
768 default:
769 googleToolChoice = &genai.ToolConfig{
770 FunctionCallingConfig: &genai.FunctionCallingConfig{
771 Mode: genai.FunctionCallingConfigModeAny,
772 AllowedFunctionNames: []string{
773 string(*toolChoice),
774 },
775 },
776 }
777 }
778 return googleTools, googleToolChoice, warnings
779}
780
781func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
782 properties := make(map[string]*genai.Schema)
783
784 for name, param := range parameters {
785 properties[name] = convertToSchema(param)
786 }
787
788 return properties
789}
790
791func convertToSchema(param any) *genai.Schema {
792 schema := &genai.Schema{Type: genai.TypeString}
793
794 paramMap, ok := param.(map[string]any)
795 if !ok {
796 return schema
797 }
798
799 if desc, ok := paramMap["description"].(string); ok {
800 schema.Description = desc
801 }
802
803 typeVal, hasType := paramMap["type"]
804 if !hasType {
805 return schema
806 }
807
808 typeStr, ok := typeVal.(string)
809 if !ok {
810 return schema
811 }
812
813 schema.Type = mapJSONTypeToGoogle(typeStr)
814
815 switch typeStr {
816 case "array":
817 schema.Items = processArrayItems(paramMap)
818 case "object":
819 if props, ok := paramMap["properties"].(map[string]any); ok {
820 schema.Properties = convertSchemaProperties(props)
821 }
822 }
823
824 return schema
825}
826
827func processArrayItems(paramMap map[string]any) *genai.Schema {
828 items, ok := paramMap["items"].(map[string]any)
829 if !ok {
830 return nil
831 }
832
833 return convertToSchema(items)
834}
835
836func mapJSONTypeToGoogle(jsonType string) genai.Type {
837 switch jsonType {
838 case "string":
839 return genai.TypeString
840 case "number":
841 return genai.TypeNumber
842 case "integer":
843 return genai.TypeInteger
844 case "boolean":
845 return genai.TypeBoolean
846 case "array":
847 return genai.TypeArray
848 case "object":
849 return genai.TypeObject
850 default:
851 return genai.TypeString // Default to string for unknown types
852 }
853}
854
855func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
856 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
857 return nil, errors.New("no response from model")
858 }
859
860 var (
861 content []ai.Content
862 finishReason ai.FinishReason
863 hasToolCalls bool
864 candidate = response.Candidates[0]
865 )
866
867 for _, part := range candidate.Content.Parts {
868 switch {
869 case part.Text != "":
870 if part.Thought {
871 content = append(content, ai.ReasoningContent{Text: part.Text})
872 } else {
873 content = append(content, ai.TextContent{Text: part.Text})
874 }
875 case part.FunctionCall != nil:
876 input, err := json.Marshal(part.FunctionCall.Args)
877 if err != nil {
878 return nil, err
879 }
880 toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
881 content = append(content, ai.ToolCallContent{
882 ToolCallID: toolCallID,
883 ToolName: part.FunctionCall.Name,
884 Input: string(input),
885 ProviderExecuted: false,
886 })
887 hasToolCalls = true
888 default:
889 // Silently skip unknown part types instead of erroring
890 // This allows for forward compatibility with new part types
891 }
892 }
893
894 if hasToolCalls {
895 finishReason = ai.FinishReasonToolCalls
896 } else {
897 finishReason = mapFinishReason(candidate.FinishReason)
898 }
899
900 return &ai.Response{
901 Content: content,
902 Usage: mapUsage(response.UsageMetadata),
903 FinishReason: finishReason,
904 Warnings: warnings,
905 }, nil
906}
907
908func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
909 switch reason {
910 case genai.FinishReasonStop:
911 return ai.FinishReasonStop
912 case genai.FinishReasonMaxTokens:
913 return ai.FinishReasonLength
914 case genai.FinishReasonSafety,
915 genai.FinishReasonBlocklist,
916 genai.FinishReasonProhibitedContent,
917 genai.FinishReasonSPII,
918 genai.FinishReasonImageSafety:
919 return ai.FinishReasonContentFilter
920 case genai.FinishReasonRecitation,
921 genai.FinishReasonLanguage,
922 genai.FinishReasonMalformedFunctionCall:
923 return ai.FinishReasonError
924 case genai.FinishReasonOther:
925 return ai.FinishReasonOther
926 default:
927 return ai.FinishReasonUnknown
928 }
929}
930
931func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
932 return ai.Usage{
933 InputTokens: int64(usage.ToolUsePromptTokenCount),
934 OutputTokens: int64(usage.CandidatesTokenCount),
935 TotalTokens: int64(usage.TotalTokenCount),
936 ReasoningTokens: int64(usage.ThoughtsTokenCount),
937 CacheCreationTokens: int64(usage.CachedContentTokenCount),
938 CacheReadTokens: 0,
939 }
940}