1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "net/http"
11 "strings"
12
13 "charm.land/fantasy"
14 "charm.land/fantasy/providers/anthropic"
15 "github.com/charmbracelet/x/exp/slice"
16 "github.com/google/uuid"
17 "google.golang.org/genai"
18)
19
// Name is the name of the Google provider. New uses it as the default
// provider name, and this package uses it as the key for its own entries in
// fantasy provider option and metadata maps (ProviderOptions,
// ReasoningMetadata).
const Name = "google"

// provider is the fantasy.Provider implementation for Google models. It only
// carries the options assembled by New.
type provider struct {
	options options
}
26
// ToolCallIDFunc defines a function that generates a tool call ID.
// It is invoked for function calls that the model returns without an ID.
type ToolCallIDFunc = func() string
29
// options holds the provider configuration assembled by New from the
// functional Option values.
type options struct {
	// apiKey is the Gemini API key; set by WithGeminiAPIKey, cleared by WithVertex.
	apiKey string
	// name is reported as the provider name of created language models;
	// New replaces an empty name with Name.
	name string
	// baseURL is passed to genai as HTTPOptions.BaseURL when non-empty.
	baseURL string
	// headers are extra HTTP headers passed to genai; WithHeaders merges into it.
	headers map[string]string
	// client is the HTTP client handed to genai, and to the Anthropic
	// provider for Claude model IDs.
	client *http.Client
	// backend selects the genai backend; only set by WithGeminiAPIKey or WithVertex.
	backend genai.Backend
	// project and location identify the Vertex AI deployment; set by WithVertex.
	project  string
	location string
	// toolCallIDFunc generates IDs for function calls returned without one;
	// New defaults it to random UUIDs.
	toolCallIDFunc ToolCallIDFunc
}
41
// Option defines a function that configures Google provider options.
// Options are applied by New in the order they are passed.
type Option = func(*options)
44
45// New creates a new Google provider with the given options.
46func New(opts ...Option) (fantasy.Provider, error) {
47 options := options{
48 headers: map[string]string{},
49 toolCallIDFunc: func() string {
50 return uuid.NewString()
51 },
52 }
53 for _, o := range opts {
54 o(&options)
55 }
56
57 options.name = cmp.Or(options.name, Name)
58
59 return &provider{
60 options: options,
61 }, nil
62}
63
64// WithBaseURL sets the base URL for the Google provider.
65func WithBaseURL(baseURL string) Option {
66 return func(o *options) {
67 o.baseURL = baseURL
68 }
69}
70
71// WithGeminiAPIKey sets the Gemini API key for the Google provider.
72func WithGeminiAPIKey(apiKey string) Option {
73 return func(o *options) {
74 o.backend = genai.BackendGeminiAPI
75 o.apiKey = apiKey
76 o.project = ""
77 o.location = ""
78 }
79}
80
81// WithVertex configures the Google provider to use Vertex AI.
82func WithVertex(project, location string) Option {
83 if project == "" || location == "" {
84 panic("project and location must be provided")
85 }
86 return func(o *options) {
87 o.backend = genai.BackendVertexAI
88 o.apiKey = ""
89 o.project = project
90 o.location = location
91 }
92}
93
94// WithName sets the name for the Google provider.
95func WithName(name string) Option {
96 return func(o *options) {
97 o.name = name
98 }
99}
100
101// WithHeaders sets the headers for the Google provider.
102func WithHeaders(headers map[string]string) Option {
103 return func(o *options) {
104 maps.Copy(o.headers, headers)
105 }
106}
107
108// WithHTTPClient sets the HTTP client for the Google provider.
109func WithHTTPClient(client *http.Client) Option {
110 return func(o *options) {
111 o.client = client
112 }
113}
114
115// WithToolCallIDFunc sets the function that generates a tool call ID.
116func WithToolCallIDFunc(f ToolCallIDFunc) Option {
117 return func(o *options) {
118 o.toolCallIDFunc = f
119 }
120}
121
122func (*provider) Name() string {
123 return Name
124}
125
// languageModel is the fantasy.LanguageModel implementation backed by a
// genai client.
type languageModel struct {
	// provider is the name reported by Provider (the provider's options.name).
	provider string
	// modelID is the model requested from the API and reported by Model.
	modelID string
	// client is the genai client used to create chat sessions.
	client *genai.Client
	// providerOptions is a copy of the provider's options at creation time.
	providerOptions options
}
132
133// LanguageModel implements fantasy.Provider.
134func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.LanguageModel, error) {
135 if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
136 p, err := anthropic.New(
137 anthropic.WithVertex(a.options.project, a.options.location),
138 anthropic.WithHTTPClient(a.options.client),
139 )
140 if err != nil {
141 return nil, err
142 }
143 return p.LanguageModel(ctx, modelID)
144 }
145
146 cc := &genai.ClientConfig{
147 HTTPClient: a.options.client,
148 Backend: a.options.backend,
149 APIKey: a.options.apiKey,
150 Project: a.options.project,
151 Location: a.options.location,
152 }
153
154 if a.options.baseURL != "" || len(a.options.headers) > 0 {
155 headers := http.Header{}
156 for k, v := range a.options.headers {
157 headers.Add(k, v)
158 }
159 cc.HTTPOptions = genai.HTTPOptions{
160 BaseURL: a.options.baseURL,
161 Headers: headers,
162 }
163 }
164 client, err := genai.NewClient(ctx, cc)
165 if err != nil {
166 return nil, err
167 }
168 return &languageModel{
169 modelID: modelID,
170 provider: a.options.name,
171 providerOptions: a.options,
172 client: client,
173 }, nil
174}
175
176func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) {
177 config := &genai.GenerateContentConfig{}
178
179 providerOptions := &ProviderOptions{}
180 if v, ok := call.ProviderOptions[Name]; ok {
181 providerOptions, ok = v.(*ProviderOptions)
182 if !ok {
183 return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
184 }
185 }
186
187 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
188
189 if providerOptions.ThinkingConfig != nil {
190 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
191 *providerOptions.ThinkingConfig.IncludeThoughts &&
192 strings.HasPrefix(g.provider, "google.vertex.") {
193 warnings = append(warnings, fantasy.CallWarning{
194 Type: fantasy.CallWarningTypeOther,
195 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
196 "and might not be supported or could behave unexpectedly with the current Google provider " +
197 fmt.Sprintf("(%s)", g.provider),
198 })
199 }
200
201 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
202 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
203 warnings = append(warnings, fantasy.CallWarning{
204 Type: fantasy.CallWarningTypeOther,
205 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
206 })
207 providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128))
208 }
209 }
210
211 isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")
212
213 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
214 if len(content) > 0 && content[0].Role == genai.RoleUser {
215 systemParts := []string{}
216 for _, sp := range systemInstructions.Parts {
217 systemParts = append(systemParts, sp.Text)
218 }
219 systemMsg := strings.Join(systemParts, "\n")
220 content[0].Parts = append([]*genai.Part{
221 {
222 Text: systemMsg + "\n\n",
223 },
224 }, content[0].Parts...)
225 systemInstructions = nil
226 }
227 }
228
229 config.SystemInstruction = systemInstructions
230
231 if call.MaxOutputTokens != nil {
232 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
233 }
234
235 if call.Temperature != nil {
236 tmp := float32(*call.Temperature)
237 config.Temperature = &tmp
238 }
239 if call.TopK != nil {
240 tmp := float32(*call.TopK)
241 config.TopK = &tmp
242 }
243 if call.TopP != nil {
244 tmp := float32(*call.TopP)
245 config.TopP = &tmp
246 }
247 if call.FrequencyPenalty != nil {
248 tmp := float32(*call.FrequencyPenalty)
249 config.FrequencyPenalty = &tmp
250 }
251 if call.PresencePenalty != nil {
252 tmp := float32(*call.PresencePenalty)
253 config.PresencePenalty = &tmp
254 }
255
256 if providerOptions.ThinkingConfig != nil {
257 config.ThinkingConfig = &genai.ThinkingConfig{}
258 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
259 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
260 }
261 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
262 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
263 config.ThinkingConfig.ThinkingBudget = &tmp
264 }
265 }
266 for _, safetySetting := range providerOptions.SafetySettings {
267 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
268 Category: genai.HarmCategory(safetySetting.Category),
269 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
270 })
271 }
272 if providerOptions.CachedContent != "" {
273 config.CachedContent = providerOptions.CachedContent
274 }
275
276 if len(call.Tools) > 0 {
277 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
278 config.ToolConfig = toolChoice
279 config.Tools = append(config.Tools, &genai.Tool{
280 FunctionDeclarations: tools,
281 })
282 warnings = append(warnings, toolWarnings...)
283 }
284
285 return config, content, warnings, nil
286}
287
288func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam
289 var systemInstructions *genai.Content
290 var content []*genai.Content
291 var warnings []fantasy.CallWarning
292
293 finishedSystemBlock := false
294 for _, msg := range prompt {
295 switch msg.Role {
296 case fantasy.MessageRoleSystem:
297 if finishedSystemBlock {
298 // skip multiple system messages that are separated by user/assistant messages
299 // TODO: see if we need to send error here?
300 continue
301 }
302 finishedSystemBlock = true
303
304 var systemMessages []string
305 for _, part := range msg.Content {
306 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
307 if !ok || text.Text == "" {
308 continue
309 }
310 systemMessages = append(systemMessages, text.Text)
311 }
312 if len(systemMessages) > 0 {
313 systemInstructions = &genai.Content{
314 Parts: []*genai.Part{
315 {
316 Text: strings.Join(systemMessages, "\n"),
317 },
318 },
319 }
320 }
321 case fantasy.MessageRoleUser:
322 var parts []*genai.Part
323 for _, part := range msg.Content {
324 switch part.GetType() {
325 case fantasy.ContentTypeText:
326 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
327 if !ok || text.Text == "" {
328 continue
329 }
330 parts = append(parts, &genai.Part{
331 Text: text.Text,
332 })
333 case fantasy.ContentTypeFile:
334 file, ok := fantasy.AsMessagePart[fantasy.FilePart](part)
335 if !ok {
336 continue
337 }
338 parts = append(parts, &genai.Part{
339 InlineData: &genai.Blob{
340 Data: file.Data,
341 MIMEType: file.MediaType,
342 },
343 })
344 }
345 }
346 if len(parts) > 0 {
347 content = append(content, &genai.Content{
348 Role: genai.RoleUser,
349 Parts: parts,
350 })
351 }
352 case fantasy.MessageRoleAssistant:
353 var parts []*genai.Part
354 // INFO: (kujtim) this is kind of a hacky way to include thinking for google
355 // weirdly thinking needs to be included in a function call
356 var signature []byte
357 for _, part := range msg.Content {
358 switch part.GetType() {
359 case fantasy.ContentTypeText:
360 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
361 if !ok || text.Text == "" {
362 continue
363 }
364 parts = append(parts, &genai.Part{
365 Text: text.Text,
366 })
367 case fantasy.ContentTypeToolCall:
368 toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
369 if !ok {
370 continue
371 }
372
373 var result map[string]any
374 err := json.Unmarshal([]byte(toolCall.Input), &result)
375 if err != nil {
376 continue
377 }
378 parts = append(parts, &genai.Part{
379 FunctionCall: &genai.FunctionCall{
380 ID: toolCall.ToolCallID,
381 Name: toolCall.ToolName,
382 Args: result,
383 },
384 ThoughtSignature: signature,
385 })
386 // reset
387 signature = nil
388 case fantasy.ContentTypeReasoning:
389 reasoning, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](part)
390 if !ok {
391 continue
392 }
393 metadata, ok := reasoning.ProviderOptions[Name]
394 if !ok {
395 continue
396 }
397 reasoningMetadata, ok := metadata.(*ReasoningMetadata)
398 if !ok {
399 continue
400 }
401 if !ok || reasoningMetadata.Signature == "" {
402 continue
403 }
404 signature = []byte(reasoningMetadata.Signature)
405 }
406 }
407 if len(parts) > 0 {
408 content = append(content, &genai.Content{
409 Role: genai.RoleModel,
410 Parts: parts,
411 })
412 }
413 case fantasy.MessageRoleTool:
414 var parts []*genai.Part
415 for _, part := range msg.Content {
416 switch part.GetType() {
417 case fantasy.ContentTypeToolResult:
418 result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
419 if !ok {
420 continue
421 }
422 var toolCall fantasy.ToolCallPart
423 for _, m := range prompt {
424 if m.Role == fantasy.MessageRoleAssistant {
425 for _, content := range m.Content {
426 tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content)
427 if !ok {
428 continue
429 }
430 if tc.ToolCallID == result.ToolCallID {
431 toolCall = tc
432 break
433 }
434 }
435 }
436 }
437 switch result.Output.GetType() {
438 case fantasy.ToolResultContentTypeText:
439 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output)
440 if !ok {
441 continue
442 }
443 response := map[string]any{"result": content.Text}
444 parts = append(parts, &genai.Part{
445 FunctionResponse: &genai.FunctionResponse{
446 ID: result.ToolCallID,
447 Response: response,
448 Name: toolCall.ToolName,
449 },
450 })
451
452 case fantasy.ToolResultContentTypeError:
453 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
454 if !ok {
455 continue
456 }
457 response := map[string]any{"result": content.Error.Error()}
458 parts = append(parts, &genai.Part{
459 FunctionResponse: &genai.FunctionResponse{
460 ID: result.ToolCallID,
461 Response: response,
462 Name: toolCall.ToolName,
463 },
464 })
465 }
466 }
467 }
468 if len(parts) > 0 {
469 content = append(content, &genai.Content{
470 Role: genai.RoleUser,
471 Parts: parts,
472 })
473 }
474 default:
475 panic("unsupported message role: " + msg.Role)
476 }
477 }
478 return systemInstructions, content, warnings
479}
480
481// Generate implements fantasy.LanguageModel.
482func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
483 config, contents, warnings, err := g.prepareParams(call)
484 if err != nil {
485 return nil, err
486 }
487
488 lastMessage, history, ok := slice.Pop(contents)
489 if !ok {
490 return nil, errors.New("no messages to send")
491 }
492
493 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
494 if err != nil {
495 return nil, err
496 }
497
498 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
499 if err != nil {
500 return nil, err
501 }
502
503 return g.mapResponse(response, warnings)
504}
505
506// Model implements fantasy.LanguageModel.
507func (g *languageModel) Model() string {
508 return g.modelID
509}
510
511// Provider implements fantasy.LanguageModel.
512func (g *languageModel) Provider() string {
513 return g.provider
514}
515
516// Stream implements fantasy.LanguageModel.
517func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
518 config, contents, warnings, err := g.prepareParams(call)
519 if err != nil {
520 return nil, err
521 }
522
523 lastMessage, history, ok := slice.Pop(contents)
524 if !ok {
525 return nil, errors.New("no messages to send")
526 }
527
528 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
529 if err != nil {
530 return nil, err
531 }
532
533 return func(yield func(fantasy.StreamPart) bool) {
534 if len(warnings) > 0 {
535 if !yield(fantasy.StreamPart{
536 Type: fantasy.StreamPartTypeWarnings,
537 Warnings: warnings,
538 }) {
539 return
540 }
541 }
542
543 var currentContent string
544 var toolCalls []fantasy.ToolCallContent
545 var isActiveText bool
546 var isActiveReasoning bool
547 var blockCounter int
548 var currentTextBlockID string
549 var currentReasoningBlockID string
550 var usage *fantasy.Usage
551 var lastFinishReason fantasy.FinishReason
552
553 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
554 if err != nil {
555 yield(fantasy.StreamPart{
556 Type: fantasy.StreamPartTypeError,
557 Error: err,
558 })
559 return
560 }
561
562 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
563 for _, part := range resp.Candidates[0].Content.Parts {
564 switch {
565 case part.Text != "":
566 delta := part.Text
567 if delta != "" {
568 // Check if this is a reasoning/thought part
569 if part.Thought {
570 // End any active text block before starting reasoning
571 if isActiveText {
572 isActiveText = false
573 if !yield(fantasy.StreamPart{
574 Type: fantasy.StreamPartTypeTextEnd,
575 ID: currentTextBlockID,
576 }) {
577 return
578 }
579 }
580
581 // Start new reasoning block if not already active
582 if !isActiveReasoning {
583 isActiveReasoning = true
584 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
585 blockCounter++
586 if !yield(fantasy.StreamPart{
587 Type: fantasy.StreamPartTypeReasoningStart,
588 ID: currentReasoningBlockID,
589 }) {
590 return
591 }
592 }
593
594 if !yield(fantasy.StreamPart{
595 Type: fantasy.StreamPartTypeReasoningDelta,
596 ID: currentReasoningBlockID,
597 Delta: delta,
598 }) {
599 return
600 }
601 } else {
602 // Regular text part
603 // End any active reasoning block before starting text
604 if isActiveReasoning {
605 isActiveReasoning = false
606 metadata := &ReasoningMetadata{
607 Signature: string(part.ThoughtSignature),
608 }
609 if !yield(fantasy.StreamPart{
610 Type: fantasy.StreamPartTypeReasoningEnd,
611 ID: currentReasoningBlockID,
612 ProviderMetadata: fantasy.ProviderMetadata{
613 Name: metadata,
614 },
615 }) {
616 return
617 }
618 }
619
620 // Start new text block if not already active
621 if !isActiveText {
622 isActiveText = true
623 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
624 blockCounter++
625 if !yield(fantasy.StreamPart{
626 Type: fantasy.StreamPartTypeTextStart,
627 ID: currentTextBlockID,
628 }) {
629 return
630 }
631 }
632
633 if !yield(fantasy.StreamPart{
634 Type: fantasy.StreamPartTypeTextDelta,
635 ID: currentTextBlockID,
636 Delta: delta,
637 }) {
638 return
639 }
640 currentContent += delta
641 }
642 }
643 case part.FunctionCall != nil:
644 // End any active text or reasoning blocks
645 if isActiveText {
646 isActiveText = false
647 if !yield(fantasy.StreamPart{
648 Type: fantasy.StreamPartTypeTextEnd,
649 ID: currentTextBlockID,
650 }) {
651 return
652 }
653 }
654 if isActiveReasoning {
655 isActiveReasoning = false
656
657 metadata := &ReasoningMetadata{
658 Signature: string(part.ThoughtSignature),
659 }
660 if !yield(fantasy.StreamPart{
661 Type: fantasy.StreamPartTypeReasoningEnd,
662 ID: currentReasoningBlockID,
663 ProviderMetadata: fantasy.ProviderMetadata{
664 Name: metadata,
665 },
666 }) {
667 return
668 }
669 }
670
671 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
672
673 args, err := json.Marshal(part.FunctionCall.Args)
674 if err != nil {
675 yield(fantasy.StreamPart{
676 Type: fantasy.StreamPartTypeError,
677 Error: err,
678 })
679 return
680 }
681
682 if !yield(fantasy.StreamPart{
683 Type: fantasy.StreamPartTypeToolInputStart,
684 ID: toolCallID,
685 ToolCallName: part.FunctionCall.Name,
686 }) {
687 return
688 }
689
690 if !yield(fantasy.StreamPart{
691 Type: fantasy.StreamPartTypeToolInputDelta,
692 ID: toolCallID,
693 Delta: string(args),
694 }) {
695 return
696 }
697
698 if !yield(fantasy.StreamPart{
699 Type: fantasy.StreamPartTypeToolInputEnd,
700 ID: toolCallID,
701 }) {
702 return
703 }
704
705 if !yield(fantasy.StreamPart{
706 Type: fantasy.StreamPartTypeToolCall,
707 ID: toolCallID,
708 ToolCallName: part.FunctionCall.Name,
709 ToolCallInput: string(args),
710 ProviderExecuted: false,
711 }) {
712 return
713 }
714
715 toolCalls = append(toolCalls, fantasy.ToolCallContent{
716 ToolCallID: toolCallID,
717 ToolName: part.FunctionCall.Name,
718 Input: string(args),
719 ProviderExecuted: false,
720 })
721 }
722 }
723 }
724
725 if resp.UsageMetadata != nil {
726 currentUsage := mapUsage(resp.UsageMetadata)
727 // if first usage chunk
728 if usage == nil {
729 usage = ¤tUsage
730 } else {
731 usage.OutputTokens += currentUsage.OutputTokens
732 usage.ReasoningTokens += currentUsage.ReasoningTokens
733 usage.CacheReadTokens += currentUsage.CacheReadTokens
734 }
735 }
736
737 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
738 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
739 }
740 }
741
742 // Close any open blocks before finishing
743 if isActiveText {
744 if !yield(fantasy.StreamPart{
745 Type: fantasy.StreamPartTypeTextEnd,
746 ID: currentTextBlockID,
747 }) {
748 return
749 }
750 }
751 if isActiveReasoning {
752 if !yield(fantasy.StreamPart{
753 Type: fantasy.StreamPartTypeReasoningEnd,
754 ID: currentReasoningBlockID,
755 }) {
756 return
757 }
758 }
759
760 finishReason := lastFinishReason
761 if len(toolCalls) > 0 {
762 finishReason = fantasy.FinishReasonToolCalls
763 } else if finishReason == "" {
764 finishReason = fantasy.FinishReasonStop
765 }
766
767 yield(fantasy.StreamPart{
768 Type: fantasy.StreamPartTypeFinish,
769 Usage: *usage,
770 FinishReason: finishReason,
771 })
772 }, nil
773}
774
775func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) {
776 for _, tool := range tools {
777 if tool.GetType() == fantasy.ToolTypeFunction {
778 ft, ok := tool.(fantasy.FunctionTool)
779 if !ok {
780 continue
781 }
782
783 required := []string{}
784 var properties map[string]any
785 if props, ok := ft.InputSchema["properties"]; ok {
786 properties, _ = props.(map[string]any)
787 }
788 if req, ok := ft.InputSchema["required"]; ok {
789 if reqArr, ok := req.([]string); ok {
790 required = reqArr
791 }
792 }
793 declaration := &genai.FunctionDeclaration{
794 Name: ft.Name,
795 Description: ft.Description,
796 Parameters: &genai.Schema{
797 Type: genai.TypeObject,
798 Properties: convertSchemaProperties(properties),
799 Required: required,
800 },
801 }
802 googleTools = append(googleTools, declaration)
803 continue
804 }
805 // TODO: handle provider tool calls
806 warnings = append(warnings, fantasy.CallWarning{
807 Type: fantasy.CallWarningTypeUnsupportedTool,
808 Tool: tool,
809 Message: "tool is not supported",
810 })
811 }
812 if toolChoice == nil {
813 return googleTools, googleToolChoice, warnings
814 }
815 switch *toolChoice {
816 case fantasy.ToolChoiceAuto:
817 googleToolChoice = &genai.ToolConfig{
818 FunctionCallingConfig: &genai.FunctionCallingConfig{
819 Mode: genai.FunctionCallingConfigModeAuto,
820 },
821 }
822 case fantasy.ToolChoiceRequired:
823 googleToolChoice = &genai.ToolConfig{
824 FunctionCallingConfig: &genai.FunctionCallingConfig{
825 Mode: genai.FunctionCallingConfigModeAny,
826 },
827 }
828 case fantasy.ToolChoiceNone:
829 googleToolChoice = &genai.ToolConfig{
830 FunctionCallingConfig: &genai.FunctionCallingConfig{
831 Mode: genai.FunctionCallingConfigModeNone,
832 },
833 }
834 default:
835 googleToolChoice = &genai.ToolConfig{
836 FunctionCallingConfig: &genai.FunctionCallingConfig{
837 Mode: genai.FunctionCallingConfigModeAny,
838 AllowedFunctionNames: []string{
839 string(*toolChoice),
840 },
841 },
842 }
843 }
844 return googleTools, googleToolChoice, warnings
845}
846
847func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
848 properties := make(map[string]*genai.Schema)
849
850 for name, param := range parameters {
851 properties[name] = convertToSchema(param)
852 }
853
854 return properties
855}
856
857func convertToSchema(param any) *genai.Schema {
858 schema := &genai.Schema{Type: genai.TypeString}
859
860 paramMap, ok := param.(map[string]any)
861 if !ok {
862 return schema
863 }
864
865 if desc, ok := paramMap["description"].(string); ok {
866 schema.Description = desc
867 }
868
869 typeVal, hasType := paramMap["type"]
870 if !hasType {
871 return schema
872 }
873
874 typeStr, ok := typeVal.(string)
875 if !ok {
876 return schema
877 }
878
879 schema.Type = mapJSONTypeToGoogle(typeStr)
880
881 switch typeStr {
882 case "array":
883 schema.Items = processArrayItems(paramMap)
884 case "object":
885 if props, ok := paramMap["properties"].(map[string]any); ok {
886 schema.Properties = convertSchemaProperties(props)
887 }
888 }
889
890 return schema
891}
892
893func processArrayItems(paramMap map[string]any) *genai.Schema {
894 items, ok := paramMap["items"].(map[string]any)
895 if !ok {
896 return nil
897 }
898
899 return convertToSchema(items)
900}
901
902func mapJSONTypeToGoogle(jsonType string) genai.Type {
903 switch jsonType {
904 case "string":
905 return genai.TypeString
906 case "number":
907 return genai.TypeNumber
908 case "integer":
909 return genai.TypeInteger
910 case "boolean":
911 return genai.TypeBoolean
912 case "array":
913 return genai.TypeArray
914 case "object":
915 return genai.TypeObject
916 default:
917 return genai.TypeString // Default to string for unknown types
918 }
919}
920
921func (g languageModel) mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) {
922 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
923 return nil, errors.New("no response from model")
924 }
925
926 var (
927 content []fantasy.Content
928 finishReason fantasy.FinishReason
929 hasToolCalls bool
930 candidate = response.Candidates[0]
931 )
932
933 for _, part := range candidate.Content.Parts {
934 switch {
935 case part.Text != "":
936 if part.Thought {
937 metadata := &ReasoningMetadata{
938 Signature: string(part.ThoughtSignature),
939 }
940 content = append(content, fantasy.ReasoningContent{Text: part.Text, ProviderMetadata: fantasy.ProviderMetadata{Name: metadata}})
941 } else {
942 content = append(content, fantasy.TextContent{Text: part.Text})
943 }
944 case part.FunctionCall != nil:
945 input, err := json.Marshal(part.FunctionCall.Args)
946 if err != nil {
947 return nil, err
948 }
949 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
950 content = append(content, fantasy.ToolCallContent{
951 ToolCallID: toolCallID,
952 ToolName: part.FunctionCall.Name,
953 Input: string(input),
954 ProviderExecuted: false,
955 })
956 hasToolCalls = true
957 default:
958 // Silently skip unknown part types instead of erroring
959 // This allows for forward compatibility with new part types
960 }
961 }
962
963 if hasToolCalls {
964 finishReason = fantasy.FinishReasonToolCalls
965 } else {
966 finishReason = mapFinishReason(candidate.FinishReason)
967 }
968
969 return &fantasy.Response{
970 Content: content,
971 Usage: mapUsage(response.UsageMetadata),
972 FinishReason: finishReason,
973 Warnings: warnings,
974 }, nil
975}
976
977func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason {
978 switch reason {
979 case genai.FinishReasonStop:
980 return fantasy.FinishReasonStop
981 case genai.FinishReasonMaxTokens:
982 return fantasy.FinishReasonLength
983 case genai.FinishReasonSafety,
984 genai.FinishReasonBlocklist,
985 genai.FinishReasonProhibitedContent,
986 genai.FinishReasonSPII,
987 genai.FinishReasonImageSafety:
988 return fantasy.FinishReasonContentFilter
989 case genai.FinishReasonRecitation,
990 genai.FinishReasonLanguage,
991 genai.FinishReasonMalformedFunctionCall:
992 return fantasy.FinishReasonError
993 case genai.FinishReasonOther:
994 return fantasy.FinishReasonOther
995 default:
996 return fantasy.FinishReasonUnknown
997 }
998}
999
1000func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage {
1001 return fantasy.Usage{
1002 InputTokens: int64(usage.PromptTokenCount),
1003 OutputTokens: int64(usage.CandidatesTokenCount),
1004 TotalTokens: int64(usage.TotalTokenCount),
1005 ReasoningTokens: int64(usage.ThoughtsTokenCount),
1006 CacheCreationTokens: 0,
1007 CacheReadTokens: int64(usage.CachedContentTokenCount),
1008 }
1009}