1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "net/http"
11 "strings"
12
13 "charm.land/fantasy"
14 "charm.land/fantasy/providers/anthropic"
15 "cloud.google.com/go/auth"
16 "github.com/charmbracelet/go-genai"
17 "github.com/charmbracelet/x/exp/slice"
18 "github.com/google/uuid"
19)
20
const (
	// Name identifies the Google provider. It is what provider.Name returns,
	// the default name given to language models created by New, and the key
	// under which Google-specific provider options and metadata are stored.
	Name = "google"
)
23
24type provider struct {
25 options options
26}
27
28// ToolCallIDFunc defines a function that generates a tool call ID.
29type ToolCallIDFunc = func() string
30
31type options struct {
32 apiKey string
33 name string
34 baseURL string
35 headers map[string]string
36 client *http.Client
37 backend genai.Backend
38 project string
39 location string
40 skipAuth bool
41 toolCallIDFunc ToolCallIDFunc
42}
43
44// Option defines a function that configures Google provider options.
45type Option = func(*options)
46
47// New creates a new Google provider with the given options.
48func New(opts ...Option) (fantasy.Provider, error) {
49 options := options{
50 headers: map[string]string{},
51 toolCallIDFunc: func() string {
52 return uuid.NewString()
53 },
54 }
55 for _, o := range opts {
56 o(&options)
57 }
58
59 options.name = cmp.Or(options.name, Name)
60
61 return &provider{
62 options: options,
63 }, nil
64}
65
66// WithBaseURL sets the base URL for the Google provider.
67func WithBaseURL(baseURL string) Option {
68 return func(o *options) {
69 o.baseURL = baseURL
70 }
71}
72
73// WithGeminiAPIKey sets the Gemini API key for the Google provider.
74func WithGeminiAPIKey(apiKey string) Option {
75 return func(o *options) {
76 o.backend = genai.BackendGeminiAPI
77 o.apiKey = apiKey
78 o.project = ""
79 o.location = ""
80 }
81}
82
83// WithVertex configures the Google provider to use Vertex AI.
84func WithVertex(project, location string) Option {
85 if project == "" || location == "" {
86 panic("project and location must be provided")
87 }
88 return func(o *options) {
89 o.backend = genai.BackendVertexAI
90 o.apiKey = ""
91 o.project = project
92 o.location = location
93 }
94}
95
96// WithSkipAuth configures whether to skip authentication for the Google provider.
97func WithSkipAuth(skipAuth bool) Option {
98 return func(o *options) {
99 o.skipAuth = skipAuth
100 }
101}
102
103// WithName sets the name for the Google provider.
104func WithName(name string) Option {
105 return func(o *options) {
106 o.name = name
107 }
108}
109
110// WithHeaders sets the headers for the Google provider.
111func WithHeaders(headers map[string]string) Option {
112 return func(o *options) {
113 maps.Copy(o.headers, headers)
114 }
115}
116
117// WithHTTPClient sets the HTTP client for the Google provider.
118func WithHTTPClient(client *http.Client) Option {
119 return func(o *options) {
120 o.client = client
121 }
122}
123
124// WithToolCallIDFunc sets the function that generates a tool call ID.
125func WithToolCallIDFunc(f ToolCallIDFunc) Option {
126 return func(o *options) {
127 o.toolCallIDFunc = f
128 }
129}
130
131func (*provider) Name() string {
132 return Name
133}
134
135type languageModel struct {
136 provider string
137 modelID string
138 client *genai.Client
139 providerOptions options
140}
141
142// LanguageModel implements fantasy.Provider.
143func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.LanguageModel, error) {
144 if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
145 p, err := anthropic.New(
146 anthropic.WithVertex(a.options.project, a.options.location),
147 anthropic.WithHTTPClient(a.options.client),
148 anthropic.WithSkipAuth(a.options.skipAuth),
149 )
150 if err != nil {
151 return nil, err
152 }
153 return p.LanguageModel(ctx, modelID)
154 }
155
156 cc := &genai.ClientConfig{
157 HTTPClient: a.options.client,
158 Backend: a.options.backend,
159 APIKey: a.options.apiKey,
160 Project: a.options.project,
161 Location: a.options.location,
162 }
163 if a.options.skipAuth {
164 cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
165 }
166
167 if a.options.baseURL != "" || len(a.options.headers) > 0 {
168 headers := http.Header{}
169 for k, v := range a.options.headers {
170 headers.Add(k, v)
171 }
172 cc.HTTPOptions = genai.HTTPOptions{
173 BaseURL: a.options.baseURL,
174 Headers: headers,
175 }
176 }
177 client, err := genai.NewClient(ctx, cc)
178 if err != nil {
179 return nil, err
180 }
181 return &languageModel{
182 modelID: modelID,
183 provider: a.options.name,
184 providerOptions: a.options,
185 client: client,
186 }, nil
187}
188
189func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) {
190 config := &genai.GenerateContentConfig{}
191
192 providerOptions := &ProviderOptions{}
193 if v, ok := call.ProviderOptions[Name]; ok {
194 providerOptions, ok = v.(*ProviderOptions)
195 if !ok {
196 return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
197 }
198 }
199
200 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
201
202 if providerOptions.ThinkingConfig != nil {
203 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
204 *providerOptions.ThinkingConfig.IncludeThoughts &&
205 strings.HasPrefix(g.provider, "google.vertex.") {
206 warnings = append(warnings, fantasy.CallWarning{
207 Type: fantasy.CallWarningTypeOther,
208 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
209 "and might not be supported or could behave unexpectedly with the current Google provider " +
210 fmt.Sprintf("(%s)", g.provider),
211 })
212 }
213
214 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
215 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
216 warnings = append(warnings, fantasy.CallWarning{
217 Type: fantasy.CallWarningTypeOther,
218 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
219 })
220 providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128))
221 }
222 }
223
224 isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")
225
226 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
227 if len(content) > 0 && content[0].Role == genai.RoleUser {
228 systemParts := []string{}
229 for _, sp := range systemInstructions.Parts {
230 systemParts = append(systemParts, sp.Text)
231 }
232 systemMsg := strings.Join(systemParts, "\n")
233 content[0].Parts = append([]*genai.Part{
234 {
235 Text: systemMsg + "\n\n",
236 },
237 }, content[0].Parts...)
238 systemInstructions = nil
239 }
240 }
241
242 config.SystemInstruction = systemInstructions
243
244 if call.MaxOutputTokens != nil {
245 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
246 }
247
248 if call.Temperature != nil {
249 tmp := float32(*call.Temperature)
250 config.Temperature = &tmp
251 }
252 if call.TopK != nil {
253 tmp := float32(*call.TopK)
254 config.TopK = &tmp
255 }
256 if call.TopP != nil {
257 tmp := float32(*call.TopP)
258 config.TopP = &tmp
259 }
260 if call.FrequencyPenalty != nil {
261 tmp := float32(*call.FrequencyPenalty)
262 config.FrequencyPenalty = &tmp
263 }
264 if call.PresencePenalty != nil {
265 tmp := float32(*call.PresencePenalty)
266 config.PresencePenalty = &tmp
267 }
268
269 if providerOptions.ThinkingConfig != nil {
270 config.ThinkingConfig = &genai.ThinkingConfig{}
271 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
272 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
273 }
274 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
275 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
276 config.ThinkingConfig.ThinkingBudget = &tmp
277 }
278 }
279 for _, safetySetting := range providerOptions.SafetySettings {
280 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
281 Category: genai.HarmCategory(safetySetting.Category),
282 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
283 })
284 }
285 if providerOptions.CachedContent != "" {
286 config.CachedContent = providerOptions.CachedContent
287 }
288
289 if len(call.Tools) > 0 {
290 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
291 config.ToolConfig = toolChoice
292 config.Tools = append(config.Tools, &genai.Tool{
293 FunctionDeclarations: tools,
294 })
295 warnings = append(warnings, toolWarnings...)
296 }
297
298 return config, content, warnings, nil
299}
300
301func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam
302 var systemInstructions *genai.Content
303 var content []*genai.Content
304 var warnings []fantasy.CallWarning
305
306 finishedSystemBlock := false
307 for _, msg := range prompt {
308 switch msg.Role {
309 case fantasy.MessageRoleSystem:
310 if finishedSystemBlock {
311 // skip multiple system messages that are separated by user/assistant messages
312 // TODO: see if we need to send error here?
313 continue
314 }
315 finishedSystemBlock = true
316
317 var systemMessages []string
318 for _, part := range msg.Content {
319 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
320 if !ok || text.Text == "" {
321 continue
322 }
323 systemMessages = append(systemMessages, text.Text)
324 }
325 if len(systemMessages) > 0 {
326 systemInstructions = &genai.Content{
327 Parts: []*genai.Part{
328 {
329 Text: strings.Join(systemMessages, "\n"),
330 },
331 },
332 }
333 }
334 case fantasy.MessageRoleUser:
335 var parts []*genai.Part
336 for _, part := range msg.Content {
337 switch part.GetType() {
338 case fantasy.ContentTypeText:
339 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
340 if !ok || text.Text == "" {
341 continue
342 }
343 parts = append(parts, &genai.Part{
344 Text: text.Text,
345 })
346 case fantasy.ContentTypeFile:
347 file, ok := fantasy.AsMessagePart[fantasy.FilePart](part)
348 if !ok {
349 continue
350 }
351 parts = append(parts, &genai.Part{
352 InlineData: &genai.Blob{
353 Data: file.Data,
354 MIMEType: file.MediaType,
355 },
356 })
357 }
358 }
359 if len(parts) > 0 {
360 content = append(content, &genai.Content{
361 Role: genai.RoleUser,
362 Parts: parts,
363 })
364 }
365 case fantasy.MessageRoleAssistant:
366 var parts []*genai.Part
367 // INFO: (kujtim) this is kind of a hacky way to include thinking for google
368 // weirdly thinking needs to be included in a function call
369 var signature []byte
370 for _, part := range msg.Content {
371 switch part.GetType() {
372 case fantasy.ContentTypeText:
373 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
374 if !ok || text.Text == "" {
375 continue
376 }
377 parts = append(parts, &genai.Part{
378 Text: text.Text,
379 })
380 case fantasy.ContentTypeToolCall:
381 toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
382 if !ok {
383 continue
384 }
385
386 var result map[string]any
387 err := json.Unmarshal([]byte(toolCall.Input), &result)
388 if err != nil {
389 continue
390 }
391 parts = append(parts, &genai.Part{
392 FunctionCall: &genai.FunctionCall{
393 ID: toolCall.ToolCallID,
394 Name: toolCall.ToolName,
395 Args: result,
396 },
397 ThoughtSignature: signature,
398 })
399 // reset
400 signature = nil
401 case fantasy.ContentTypeReasoning:
402 reasoning, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](part)
403 if !ok {
404 continue
405 }
406 metadata, ok := reasoning.ProviderOptions[Name]
407 if !ok {
408 continue
409 }
410 reasoningMetadata, ok := metadata.(*ReasoningMetadata)
411 if !ok {
412 continue
413 }
414 if !ok || reasoningMetadata.Signature == "" {
415 continue
416 }
417 signature = []byte(reasoningMetadata.Signature)
418 }
419 }
420 if len(parts) > 0 {
421 content = append(content, &genai.Content{
422 Role: genai.RoleModel,
423 Parts: parts,
424 })
425 }
426 case fantasy.MessageRoleTool:
427 var parts []*genai.Part
428 for _, part := range msg.Content {
429 switch part.GetType() {
430 case fantasy.ContentTypeToolResult:
431 result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
432 if !ok {
433 continue
434 }
435 var toolCall fantasy.ToolCallPart
436 for _, m := range prompt {
437 if m.Role == fantasy.MessageRoleAssistant {
438 for _, content := range m.Content {
439 tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content)
440 if !ok {
441 continue
442 }
443 if tc.ToolCallID == result.ToolCallID {
444 toolCall = tc
445 break
446 }
447 }
448 }
449 }
450 switch result.Output.GetType() {
451 case fantasy.ToolResultContentTypeText:
452 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output)
453 if !ok {
454 continue
455 }
456 response := map[string]any{"result": content.Text}
457 parts = append(parts, &genai.Part{
458 FunctionResponse: &genai.FunctionResponse{
459 ID: result.ToolCallID,
460 Response: response,
461 Name: toolCall.ToolName,
462 },
463 })
464
465 case fantasy.ToolResultContentTypeError:
466 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
467 if !ok {
468 continue
469 }
470 response := map[string]any{"result": content.Error.Error()}
471 parts = append(parts, &genai.Part{
472 FunctionResponse: &genai.FunctionResponse{
473 ID: result.ToolCallID,
474 Response: response,
475 Name: toolCall.ToolName,
476 },
477 })
478 }
479 }
480 }
481 if len(parts) > 0 {
482 content = append(content, &genai.Content{
483 Role: genai.RoleUser,
484 Parts: parts,
485 })
486 }
487 default:
488 panic("unsupported message role: " + msg.Role)
489 }
490 }
491 return systemInstructions, content, warnings
492}
493
494// Generate implements fantasy.LanguageModel.
495func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
496 config, contents, warnings, err := g.prepareParams(call)
497 if err != nil {
498 return nil, err
499 }
500
501 lastMessage, history, ok := slice.Pop(contents)
502 if !ok {
503 return nil, errors.New("no messages to send")
504 }
505
506 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
507 if err != nil {
508 return nil, err
509 }
510
511 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
512 if err != nil {
513 return nil, err
514 }
515
516 return g.mapResponse(response, warnings)
517}
518
519// Model implements fantasy.LanguageModel.
520func (g *languageModel) Model() string {
521 return g.modelID
522}
523
524// Provider implements fantasy.LanguageModel.
525func (g *languageModel) Provider() string {
526 return g.provider
527}
528
529// Stream implements fantasy.LanguageModel.
530func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
531 config, contents, warnings, err := g.prepareParams(call)
532 if err != nil {
533 return nil, err
534 }
535
536 lastMessage, history, ok := slice.Pop(contents)
537 if !ok {
538 return nil, errors.New("no messages to send")
539 }
540
541 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
542 if err != nil {
543 return nil, err
544 }
545
546 return func(yield func(fantasy.StreamPart) bool) {
547 if len(warnings) > 0 {
548 if !yield(fantasy.StreamPart{
549 Type: fantasy.StreamPartTypeWarnings,
550 Warnings: warnings,
551 }) {
552 return
553 }
554 }
555
556 var currentContent string
557 var toolCalls []fantasy.ToolCallContent
558 var isActiveText bool
559 var isActiveReasoning bool
560 var blockCounter int
561 var currentTextBlockID string
562 var currentReasoningBlockID string
563 var usage *fantasy.Usage
564 var lastFinishReason fantasy.FinishReason
565
566 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
567 if err != nil {
568 yield(fantasy.StreamPart{
569 Type: fantasy.StreamPartTypeError,
570 Error: err,
571 })
572 return
573 }
574
575 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
576 for _, part := range resp.Candidates[0].Content.Parts {
577 switch {
578 case part.Text != "":
579 delta := part.Text
580 if delta != "" {
581 // Check if this is a reasoning/thought part
582 if part.Thought {
583 // End any active text block before starting reasoning
584 if isActiveText {
585 isActiveText = false
586 if !yield(fantasy.StreamPart{
587 Type: fantasy.StreamPartTypeTextEnd,
588 ID: currentTextBlockID,
589 }) {
590 return
591 }
592 }
593
594 // Start new reasoning block if not already active
595 if !isActiveReasoning {
596 isActiveReasoning = true
597 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
598 blockCounter++
599 if !yield(fantasy.StreamPart{
600 Type: fantasy.StreamPartTypeReasoningStart,
601 ID: currentReasoningBlockID,
602 }) {
603 return
604 }
605 }
606
607 if !yield(fantasy.StreamPart{
608 Type: fantasy.StreamPartTypeReasoningDelta,
609 ID: currentReasoningBlockID,
610 Delta: delta,
611 }) {
612 return
613 }
614 } else {
615 // Regular text part
616 // End any active reasoning block before starting text
617 if isActiveReasoning {
618 isActiveReasoning = false
619 metadata := &ReasoningMetadata{
620 Signature: string(part.ThoughtSignature),
621 }
622 if !yield(fantasy.StreamPart{
623 Type: fantasy.StreamPartTypeReasoningEnd,
624 ID: currentReasoningBlockID,
625 ProviderMetadata: fantasy.ProviderMetadata{
626 Name: metadata,
627 },
628 }) {
629 return
630 }
631 }
632
633 // Start new text block if not already active
634 if !isActiveText {
635 isActiveText = true
636 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
637 blockCounter++
638 if !yield(fantasy.StreamPart{
639 Type: fantasy.StreamPartTypeTextStart,
640 ID: currentTextBlockID,
641 }) {
642 return
643 }
644 }
645
646 if !yield(fantasy.StreamPart{
647 Type: fantasy.StreamPartTypeTextDelta,
648 ID: currentTextBlockID,
649 Delta: delta,
650 }) {
651 return
652 }
653 currentContent += delta
654 }
655 }
656 case part.FunctionCall != nil:
657 // End any active text or reasoning blocks
658 if isActiveText {
659 isActiveText = false
660 if !yield(fantasy.StreamPart{
661 Type: fantasy.StreamPartTypeTextEnd,
662 ID: currentTextBlockID,
663 }) {
664 return
665 }
666 }
667 if isActiveReasoning {
668 isActiveReasoning = false
669
670 metadata := &ReasoningMetadata{
671 Signature: string(part.ThoughtSignature),
672 }
673 if !yield(fantasy.StreamPart{
674 Type: fantasy.StreamPartTypeReasoningEnd,
675 ID: currentReasoningBlockID,
676 ProviderMetadata: fantasy.ProviderMetadata{
677 Name: metadata,
678 },
679 }) {
680 return
681 }
682 }
683
684 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
685
686 args, err := json.Marshal(part.FunctionCall.Args)
687 if err != nil {
688 yield(fantasy.StreamPart{
689 Type: fantasy.StreamPartTypeError,
690 Error: err,
691 })
692 return
693 }
694
695 if !yield(fantasy.StreamPart{
696 Type: fantasy.StreamPartTypeToolInputStart,
697 ID: toolCallID,
698 ToolCallName: part.FunctionCall.Name,
699 }) {
700 return
701 }
702
703 if !yield(fantasy.StreamPart{
704 Type: fantasy.StreamPartTypeToolInputDelta,
705 ID: toolCallID,
706 Delta: string(args),
707 }) {
708 return
709 }
710
711 if !yield(fantasy.StreamPart{
712 Type: fantasy.StreamPartTypeToolInputEnd,
713 ID: toolCallID,
714 }) {
715 return
716 }
717
718 if !yield(fantasy.StreamPart{
719 Type: fantasy.StreamPartTypeToolCall,
720 ID: toolCallID,
721 ToolCallName: part.FunctionCall.Name,
722 ToolCallInput: string(args),
723 ProviderExecuted: false,
724 }) {
725 return
726 }
727
728 toolCalls = append(toolCalls, fantasy.ToolCallContent{
729 ToolCallID: toolCallID,
730 ToolName: part.FunctionCall.Name,
731 Input: string(args),
732 ProviderExecuted: false,
733 })
734 }
735 }
736 }
737
738 if resp.UsageMetadata != nil {
739 currentUsage := mapUsage(resp.UsageMetadata)
740 // if first usage chunk
741 if usage == nil {
742 usage = ¤tUsage
743 } else {
744 usage.OutputTokens += currentUsage.OutputTokens
745 usage.ReasoningTokens += currentUsage.ReasoningTokens
746 usage.CacheReadTokens += currentUsage.CacheReadTokens
747 }
748 }
749
750 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
751 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
752 }
753 }
754
755 // Close any open blocks before finishing
756 if isActiveText {
757 if !yield(fantasy.StreamPart{
758 Type: fantasy.StreamPartTypeTextEnd,
759 ID: currentTextBlockID,
760 }) {
761 return
762 }
763 }
764 if isActiveReasoning {
765 if !yield(fantasy.StreamPart{
766 Type: fantasy.StreamPartTypeReasoningEnd,
767 ID: currentReasoningBlockID,
768 }) {
769 return
770 }
771 }
772
773 finishReason := lastFinishReason
774 if len(toolCalls) > 0 {
775 finishReason = fantasy.FinishReasonToolCalls
776 } else if finishReason == "" {
777 finishReason = fantasy.FinishReasonStop
778 }
779
780 yield(fantasy.StreamPart{
781 Type: fantasy.StreamPartTypeFinish,
782 Usage: *usage,
783 FinishReason: finishReason,
784 })
785 }, nil
786}
787
788func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) {
789 for _, tool := range tools {
790 if tool.GetType() == fantasy.ToolTypeFunction {
791 ft, ok := tool.(fantasy.FunctionTool)
792 if !ok {
793 continue
794 }
795
796 required := []string{}
797 var properties map[string]any
798 if props, ok := ft.InputSchema["properties"]; ok {
799 properties, _ = props.(map[string]any)
800 }
801 if req, ok := ft.InputSchema["required"]; ok {
802 if reqArr, ok := req.([]string); ok {
803 required = reqArr
804 }
805 }
806 declaration := &genai.FunctionDeclaration{
807 Name: ft.Name,
808 Description: ft.Description,
809 Parameters: &genai.Schema{
810 Type: genai.TypeObject,
811 Properties: convertSchemaProperties(properties),
812 Required: required,
813 },
814 }
815 googleTools = append(googleTools, declaration)
816 continue
817 }
818 // TODO: handle provider tool calls
819 warnings = append(warnings, fantasy.CallWarning{
820 Type: fantasy.CallWarningTypeUnsupportedTool,
821 Tool: tool,
822 Message: "tool is not supported",
823 })
824 }
825 if toolChoice == nil {
826 return googleTools, googleToolChoice, warnings
827 }
828 switch *toolChoice {
829 case fantasy.ToolChoiceAuto:
830 googleToolChoice = &genai.ToolConfig{
831 FunctionCallingConfig: &genai.FunctionCallingConfig{
832 Mode: genai.FunctionCallingConfigModeAuto,
833 },
834 }
835 case fantasy.ToolChoiceRequired:
836 googleToolChoice = &genai.ToolConfig{
837 FunctionCallingConfig: &genai.FunctionCallingConfig{
838 Mode: genai.FunctionCallingConfigModeAny,
839 },
840 }
841 case fantasy.ToolChoiceNone:
842 googleToolChoice = &genai.ToolConfig{
843 FunctionCallingConfig: &genai.FunctionCallingConfig{
844 Mode: genai.FunctionCallingConfigModeNone,
845 },
846 }
847 default:
848 googleToolChoice = &genai.ToolConfig{
849 FunctionCallingConfig: &genai.FunctionCallingConfig{
850 Mode: genai.FunctionCallingConfigModeAny,
851 AllowedFunctionNames: []string{
852 string(*toolChoice),
853 },
854 },
855 }
856 }
857 return googleTools, googleToolChoice, warnings
858}
859
860func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
861 properties := make(map[string]*genai.Schema)
862
863 for name, param := range parameters {
864 properties[name] = convertToSchema(param)
865 }
866
867 return properties
868}
869
870func convertToSchema(param any) *genai.Schema {
871 schema := &genai.Schema{Type: genai.TypeString}
872
873 paramMap, ok := param.(map[string]any)
874 if !ok {
875 return schema
876 }
877
878 if desc, ok := paramMap["description"].(string); ok {
879 schema.Description = desc
880 }
881
882 typeVal, hasType := paramMap["type"]
883 if !hasType {
884 return schema
885 }
886
887 typeStr, ok := typeVal.(string)
888 if !ok {
889 return schema
890 }
891
892 schema.Type = mapJSONTypeToGoogle(typeStr)
893
894 switch typeStr {
895 case "array":
896 schema.Items = processArrayItems(paramMap)
897 case "object":
898 if props, ok := paramMap["properties"].(map[string]any); ok {
899 schema.Properties = convertSchemaProperties(props)
900 }
901 }
902
903 return schema
904}
905
906func processArrayItems(paramMap map[string]any) *genai.Schema {
907 items, ok := paramMap["items"].(map[string]any)
908 if !ok {
909 return nil
910 }
911
912 return convertToSchema(items)
913}
914
915func mapJSONTypeToGoogle(jsonType string) genai.Type {
916 switch jsonType {
917 case "string":
918 return genai.TypeString
919 case "number":
920 return genai.TypeNumber
921 case "integer":
922 return genai.TypeInteger
923 case "boolean":
924 return genai.TypeBoolean
925 case "array":
926 return genai.TypeArray
927 case "object":
928 return genai.TypeObject
929 default:
930 return genai.TypeString // Default to string for unknown types
931 }
932}
933
934func (g languageModel) mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) {
935 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
936 return nil, errors.New("no response from model")
937 }
938
939 var (
940 content []fantasy.Content
941 finishReason fantasy.FinishReason
942 hasToolCalls bool
943 candidate = response.Candidates[0]
944 )
945
946 for _, part := range candidate.Content.Parts {
947 switch {
948 case part.Text != "":
949 if part.Thought {
950 metadata := &ReasoningMetadata{
951 Signature: string(part.ThoughtSignature),
952 }
953 content = append(content, fantasy.ReasoningContent{Text: part.Text, ProviderMetadata: fantasy.ProviderMetadata{Name: metadata}})
954 } else {
955 content = append(content, fantasy.TextContent{Text: part.Text})
956 }
957 case part.FunctionCall != nil:
958 input, err := json.Marshal(part.FunctionCall.Args)
959 if err != nil {
960 return nil, err
961 }
962 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
963 content = append(content, fantasy.ToolCallContent{
964 ToolCallID: toolCallID,
965 ToolName: part.FunctionCall.Name,
966 Input: string(input),
967 ProviderExecuted: false,
968 })
969 hasToolCalls = true
970 default:
971 // Silently skip unknown part types instead of erroring
972 // This allows for forward compatibility with new part types
973 }
974 }
975
976 if hasToolCalls {
977 finishReason = fantasy.FinishReasonToolCalls
978 } else {
979 finishReason = mapFinishReason(candidate.FinishReason)
980 }
981
982 return &fantasy.Response{
983 Content: content,
984 Usage: mapUsage(response.UsageMetadata),
985 FinishReason: finishReason,
986 Warnings: warnings,
987 }, nil
988}
989
990func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason {
991 switch reason {
992 case genai.FinishReasonStop:
993 return fantasy.FinishReasonStop
994 case genai.FinishReasonMaxTokens:
995 return fantasy.FinishReasonLength
996 case genai.FinishReasonSafety,
997 genai.FinishReasonBlocklist,
998 genai.FinishReasonProhibitedContent,
999 genai.FinishReasonSPII,
1000 genai.FinishReasonImageSafety:
1001 return fantasy.FinishReasonContentFilter
1002 case genai.FinishReasonRecitation,
1003 genai.FinishReasonLanguage,
1004 genai.FinishReasonMalformedFunctionCall:
1005 return fantasy.FinishReasonError
1006 case genai.FinishReasonOther:
1007 return fantasy.FinishReasonOther
1008 default:
1009 return fantasy.FinishReasonUnknown
1010 }
1011}
1012
1013func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage {
1014 return fantasy.Usage{
1015 InputTokens: int64(usage.PromptTokenCount),
1016 OutputTokens: int64(usage.CandidatesTokenCount),
1017 TotalTokens: int64(usage.TotalTokenCount),
1018 ReasoningTokens: int64(usage.ThoughtsTokenCount),
1019 CacheCreationTokens: 0,
1020 CacheReadTokens: int64(usage.CachedContentTokenCount),
1021 }
1022}