1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "net/http"
11 "strings"
12
13 "charm.land/fantasy"
14 "charm.land/fantasy/providers/anthropic"
15 "cloud.google.com/go/auth"
16 "github.com/charmbracelet/x/exp/slice"
17 "github.com/google/uuid"
18 "google.golang.org/genai"
19)
20
// Name is the identifier of the Google provider. It is the default provider
// name used by New, and the key under which Google-specific provider options
// (call.ProviderOptions) and reasoning metadata are stored.
const Name = "google"
23
// provider is the fantasy.Provider implementation for Google models. It only
// holds the configuration assembled by New; clients are created per model in
// LanguageModel.
type provider struct {
	options options
}
27
// ToolCallIDFunc defines a function that generates a tool call ID. It is used
// when the model returns a function call without an ID of its own.
type ToolCallIDFunc = func() string
30
// options holds the configuration collected from the Option values passed to New.
type options struct {
	apiKey         string            // Gemini API key (set by WithGeminiAPIKey)
	name           string            // provider name reported by language models; defaults to Name
	baseURL        string            // optional base URL override for the genai client
	headers        map[string]string // extra HTTP headers for the genai client
	client         *http.Client      // HTTP client for the genai client and the delegated Anthropic provider
	backend        genai.Backend     // genai backend: Gemini API or Vertex AI
	project        string            // Vertex AI project (set by WithVertex)
	location       string            // Vertex AI location (set by WithVertex)
	skipAuth       bool              // when true, a dummy token provider is used instead of real credentials
	toolCallIDFunc ToolCallIDFunc    // generates IDs for function calls returned without one
}
43
// Option defines a function that configures Google provider options. Options
// are applied in the order they are passed to New.
type Option = func(*options)
46
47// New creates a new Google provider with the given options.
48func New(opts ...Option) (fantasy.Provider, error) {
49 options := options{
50 headers: map[string]string{},
51 toolCallIDFunc: func() string {
52 return uuid.NewString()
53 },
54 }
55 for _, o := range opts {
56 o(&options)
57 }
58
59 options.name = cmp.Or(options.name, Name)
60
61 return &provider{
62 options: options,
63 }, nil
64}
65
66// WithBaseURL sets the base URL for the Google provider.
67func WithBaseURL(baseURL string) Option {
68 return func(o *options) {
69 o.baseURL = baseURL
70 }
71}
72
73// WithGeminiAPIKey sets the Gemini API key for the Google provider.
74func WithGeminiAPIKey(apiKey string) Option {
75 return func(o *options) {
76 o.backend = genai.BackendGeminiAPI
77 o.apiKey = apiKey
78 o.project = ""
79 o.location = ""
80 }
81}
82
83// WithVertex configures the Google provider to use Vertex AI.
84func WithVertex(project, location string) Option {
85 if project == "" || location == "" {
86 panic("project and location must be provided")
87 }
88 return func(o *options) {
89 o.backend = genai.BackendVertexAI
90 o.apiKey = ""
91 o.project = project
92 o.location = location
93 }
94}
95
96// WithSkipAuth configures whether to skip authentication for the Google provider.
97func WithSkipAuth(skipAuth bool) Option {
98 return func(o *options) {
99 o.skipAuth = skipAuth
100 }
101}
102
103// WithName sets the name for the Google provider.
104func WithName(name string) Option {
105 return func(o *options) {
106 o.name = name
107 }
108}
109
110// WithHeaders sets the headers for the Google provider.
111func WithHeaders(headers map[string]string) Option {
112 return func(o *options) {
113 maps.Copy(o.headers, headers)
114 }
115}
116
117// WithHTTPClient sets the HTTP client for the Google provider.
118func WithHTTPClient(client *http.Client) Option {
119 return func(o *options) {
120 o.client = client
121 }
122}
123
124// WithToolCallIDFunc sets the function that generates a tool call ID.
125func WithToolCallIDFunc(f ToolCallIDFunc) Option {
126 return func(o *options) {
127 o.toolCallIDFunc = f
128 }
129}
130
131func (*provider) Name() string {
132 return Name
133}
134
// languageModel is the fantasy.LanguageModel implementation backed by a genai
// client for a single model ID.
type languageModel struct {
	provider        string        // provider name reported by Provider (options.name)
	modelID         string        // model ID passed to every genai request
	client          *genai.Client // client built in provider.LanguageModel
	providerOptions options       // copy of the provider configuration (e.g. toolCallIDFunc, backend)
}
141
142// LanguageModel implements fantasy.Provider.
143func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.LanguageModel, error) {
144 if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
145 p, err := anthropic.New(
146 anthropic.WithVertex(a.options.project, a.options.location),
147 anthropic.WithHTTPClient(a.options.client),
148 anthropic.WithSkipAuth(a.options.skipAuth),
149 )
150 if err != nil {
151 return nil, err
152 }
153 return p.LanguageModel(ctx, modelID)
154 }
155
156 cc := &genai.ClientConfig{
157 HTTPClient: a.options.client,
158 Backend: a.options.backend,
159 APIKey: a.options.apiKey,
160 Project: a.options.project,
161 Location: a.options.location,
162 }
163 if a.options.skipAuth {
164 cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
165 } else if cc.Backend == genai.BackendVertexAI {
166 if err := cc.UseDefaultCredentials(); err != nil {
167 return nil, err
168 }
169 }
170
171 if a.options.baseURL != "" || len(a.options.headers) > 0 {
172 headers := http.Header{}
173 for k, v := range a.options.headers {
174 headers.Add(k, v)
175 }
176 cc.HTTPOptions = genai.HTTPOptions{
177 BaseURL: a.options.baseURL,
178 Headers: headers,
179 }
180 }
181 client, err := genai.NewClient(ctx, cc)
182 if err != nil {
183 return nil, err
184 }
185 return &languageModel{
186 modelID: modelID,
187 provider: a.options.name,
188 providerOptions: a.options,
189 client: client,
190 }, nil
191}
192
193func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) {
194 config := &genai.GenerateContentConfig{}
195
196 providerOptions := &ProviderOptions{}
197 if v, ok := call.ProviderOptions[Name]; ok {
198 providerOptions, ok = v.(*ProviderOptions)
199 if !ok {
200 return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
201 }
202 }
203
204 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
205
206 if providerOptions.ThinkingConfig != nil {
207 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
208 *providerOptions.ThinkingConfig.IncludeThoughts &&
209 strings.HasPrefix(g.provider, "google.vertex.") {
210 warnings = append(warnings, fantasy.CallWarning{
211 Type: fantasy.CallWarningTypeOther,
212 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
213 "and might not be supported or could behave unexpectedly with the current Google provider " +
214 fmt.Sprintf("(%s)", g.provider),
215 })
216 }
217
218 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
219 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
220 warnings = append(warnings, fantasy.CallWarning{
221 Type: fantasy.CallWarningTypeOther,
222 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
223 })
224 providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128))
225 }
226 }
227
228 isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")
229
230 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
231 if len(content) > 0 && content[0].Role == genai.RoleUser {
232 systemParts := []string{}
233 for _, sp := range systemInstructions.Parts {
234 systemParts = append(systemParts, sp.Text)
235 }
236 systemMsg := strings.Join(systemParts, "\n")
237 content[0].Parts = append([]*genai.Part{
238 {
239 Text: systemMsg + "\n\n",
240 },
241 }, content[0].Parts...)
242 systemInstructions = nil
243 }
244 }
245
246 config.SystemInstruction = systemInstructions
247
248 if call.MaxOutputTokens != nil {
249 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
250 }
251
252 if call.Temperature != nil {
253 tmp := float32(*call.Temperature)
254 config.Temperature = &tmp
255 }
256 if call.TopK != nil {
257 tmp := float32(*call.TopK)
258 config.TopK = &tmp
259 }
260 if call.TopP != nil {
261 tmp := float32(*call.TopP)
262 config.TopP = &tmp
263 }
264 if call.FrequencyPenalty != nil {
265 tmp := float32(*call.FrequencyPenalty)
266 config.FrequencyPenalty = &tmp
267 }
268 if call.PresencePenalty != nil {
269 tmp := float32(*call.PresencePenalty)
270 config.PresencePenalty = &tmp
271 }
272
273 if providerOptions.ThinkingConfig != nil {
274 config.ThinkingConfig = &genai.ThinkingConfig{}
275 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
276 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
277 }
278 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
279 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
280 config.ThinkingConfig.ThinkingBudget = &tmp
281 }
282 }
283 for _, safetySetting := range providerOptions.SafetySettings {
284 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
285 Category: genai.HarmCategory(safetySetting.Category),
286 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
287 })
288 }
289 if providerOptions.CachedContent != "" {
290 config.CachedContent = providerOptions.CachedContent
291 }
292
293 if len(call.Tools) > 0 {
294 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
295 config.ToolConfig = toolChoice
296 config.Tools = append(config.Tools, &genai.Tool{
297 FunctionDeclarations: tools,
298 })
299 warnings = append(warnings, toolWarnings...)
300 }
301
302 return config, content, warnings, nil
303}
304
305func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam
306 var systemInstructions *genai.Content
307 var content []*genai.Content
308 var warnings []fantasy.CallWarning
309
310 finishedSystemBlock := false
311 for _, msg := range prompt {
312 switch msg.Role {
313 case fantasy.MessageRoleSystem:
314 if finishedSystemBlock {
315 // skip multiple system messages that are separated by user/assistant messages
316 // TODO: see if we need to send error here?
317 continue
318 }
319 finishedSystemBlock = true
320
321 var systemMessages []string
322 for _, part := range msg.Content {
323 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
324 if !ok || text.Text == "" {
325 continue
326 }
327 systemMessages = append(systemMessages, text.Text)
328 }
329 if len(systemMessages) > 0 {
330 systemInstructions = &genai.Content{
331 Parts: []*genai.Part{
332 {
333 Text: strings.Join(systemMessages, "\n"),
334 },
335 },
336 }
337 }
338 case fantasy.MessageRoleUser:
339 var parts []*genai.Part
340 for _, part := range msg.Content {
341 switch part.GetType() {
342 case fantasy.ContentTypeText:
343 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
344 if !ok || text.Text == "" {
345 continue
346 }
347 parts = append(parts, &genai.Part{
348 Text: text.Text,
349 })
350 case fantasy.ContentTypeFile:
351 file, ok := fantasy.AsMessagePart[fantasy.FilePart](part)
352 if !ok {
353 continue
354 }
355 parts = append(parts, &genai.Part{
356 InlineData: &genai.Blob{
357 Data: file.Data,
358 MIMEType: file.MediaType,
359 },
360 })
361 }
362 }
363 if len(parts) > 0 {
364 content = append(content, &genai.Content{
365 Role: genai.RoleUser,
366 Parts: parts,
367 })
368 }
369 case fantasy.MessageRoleAssistant:
370 var parts []*genai.Part
371 // INFO: (kujtim) this is kind of a hacky way to include thinking for google
372 // weirdly thinking needs to be included in a function call
373 var signature []byte
374 for _, part := range msg.Content {
375 switch part.GetType() {
376 case fantasy.ContentTypeText:
377 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
378 if !ok || text.Text == "" {
379 continue
380 }
381 parts = append(parts, &genai.Part{
382 Text: text.Text,
383 })
384 case fantasy.ContentTypeToolCall:
385 toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
386 if !ok {
387 continue
388 }
389
390 var result map[string]any
391 err := json.Unmarshal([]byte(toolCall.Input), &result)
392 if err != nil {
393 continue
394 }
395 parts = append(parts, &genai.Part{
396 FunctionCall: &genai.FunctionCall{
397 ID: toolCall.ToolCallID,
398 Name: toolCall.ToolName,
399 Args: result,
400 },
401 ThoughtSignature: signature,
402 })
403 // reset
404 signature = nil
405 case fantasy.ContentTypeReasoning:
406 reasoning, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](part)
407 if !ok {
408 continue
409 }
410 metadata, ok := reasoning.ProviderOptions[Name]
411 if !ok {
412 continue
413 }
414 reasoningMetadata, ok := metadata.(*ReasoningMetadata)
415 if !ok {
416 continue
417 }
418 if !ok || reasoningMetadata.Signature == "" {
419 continue
420 }
421 signature = []byte(reasoningMetadata.Signature)
422 }
423 }
424 if len(parts) > 0 {
425 content = append(content, &genai.Content{
426 Role: genai.RoleModel,
427 Parts: parts,
428 })
429 }
430 case fantasy.MessageRoleTool:
431 var parts []*genai.Part
432 for _, part := range msg.Content {
433 switch part.GetType() {
434 case fantasy.ContentTypeToolResult:
435 result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
436 if !ok {
437 continue
438 }
439 var toolCall fantasy.ToolCallPart
440 for _, m := range prompt {
441 if m.Role == fantasy.MessageRoleAssistant {
442 for _, content := range m.Content {
443 tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content)
444 if !ok {
445 continue
446 }
447 if tc.ToolCallID == result.ToolCallID {
448 toolCall = tc
449 break
450 }
451 }
452 }
453 }
454 switch result.Output.GetType() {
455 case fantasy.ToolResultContentTypeText:
456 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output)
457 if !ok {
458 continue
459 }
460 response := map[string]any{"result": content.Text}
461 parts = append(parts, &genai.Part{
462 FunctionResponse: &genai.FunctionResponse{
463 ID: result.ToolCallID,
464 Response: response,
465 Name: toolCall.ToolName,
466 },
467 })
468
469 case fantasy.ToolResultContentTypeError:
470 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
471 if !ok {
472 continue
473 }
474 response := map[string]any{"result": content.Error.Error()}
475 parts = append(parts, &genai.Part{
476 FunctionResponse: &genai.FunctionResponse{
477 ID: result.ToolCallID,
478 Response: response,
479 Name: toolCall.ToolName,
480 },
481 })
482 }
483 }
484 }
485 if len(parts) > 0 {
486 content = append(content, &genai.Content{
487 Role: genai.RoleUser,
488 Parts: parts,
489 })
490 }
491 default:
492 panic("unsupported message role: " + msg.Role)
493 }
494 }
495 return systemInstructions, content, warnings
496}
497
498// Generate implements fantasy.LanguageModel.
499func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
500 config, contents, warnings, err := g.prepareParams(call)
501 if err != nil {
502 return nil, err
503 }
504
505 lastMessage, history, ok := slice.Pop(contents)
506 if !ok {
507 return nil, errors.New("no messages to send")
508 }
509
510 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
511 if err != nil {
512 return nil, err
513 }
514
515 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
516 if err != nil {
517 return nil, err
518 }
519
520 return g.mapResponse(response, warnings)
521}
522
523// Model implements fantasy.LanguageModel.
524func (g *languageModel) Model() string {
525 return g.modelID
526}
527
528// Provider implements fantasy.LanguageModel.
529func (g *languageModel) Provider() string {
530 return g.provider
531}
532
533// Stream implements fantasy.LanguageModel.
534func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
535 config, contents, warnings, err := g.prepareParams(call)
536 if err != nil {
537 return nil, err
538 }
539
540 lastMessage, history, ok := slice.Pop(contents)
541 if !ok {
542 return nil, errors.New("no messages to send")
543 }
544
545 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
546 if err != nil {
547 return nil, err
548 }
549
550 return func(yield func(fantasy.StreamPart) bool) {
551 if len(warnings) > 0 {
552 if !yield(fantasy.StreamPart{
553 Type: fantasy.StreamPartTypeWarnings,
554 Warnings: warnings,
555 }) {
556 return
557 }
558 }
559
560 var currentContent string
561 var toolCalls []fantasy.ToolCallContent
562 var isActiveText bool
563 var isActiveReasoning bool
564 var blockCounter int
565 var currentTextBlockID string
566 var currentReasoningBlockID string
567 var usage *fantasy.Usage
568 var lastFinishReason fantasy.FinishReason
569
570 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
571 if err != nil {
572 yield(fantasy.StreamPart{
573 Type: fantasy.StreamPartTypeError,
574 Error: err,
575 })
576 return
577 }
578
579 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
580 for _, part := range resp.Candidates[0].Content.Parts {
581 switch {
582 case part.Text != "":
583 delta := part.Text
584 if delta != "" {
585 // Check if this is a reasoning/thought part
586 if part.Thought {
587 // End any active text block before starting reasoning
588 if isActiveText {
589 isActiveText = false
590 if !yield(fantasy.StreamPart{
591 Type: fantasy.StreamPartTypeTextEnd,
592 ID: currentTextBlockID,
593 }) {
594 return
595 }
596 }
597
598 // Start new reasoning block if not already active
599 if !isActiveReasoning {
600 isActiveReasoning = true
601 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
602 blockCounter++
603 if !yield(fantasy.StreamPart{
604 Type: fantasy.StreamPartTypeReasoningStart,
605 ID: currentReasoningBlockID,
606 }) {
607 return
608 }
609 }
610
611 if !yield(fantasy.StreamPart{
612 Type: fantasy.StreamPartTypeReasoningDelta,
613 ID: currentReasoningBlockID,
614 Delta: delta,
615 }) {
616 return
617 }
618 } else {
619 // Regular text part
620 // End any active reasoning block before starting text
621 if isActiveReasoning {
622 isActiveReasoning = false
623 metadata := &ReasoningMetadata{
624 Signature: string(part.ThoughtSignature),
625 }
626 if !yield(fantasy.StreamPart{
627 Type: fantasy.StreamPartTypeReasoningEnd,
628 ID: currentReasoningBlockID,
629 ProviderMetadata: fantasy.ProviderMetadata{
630 Name: metadata,
631 },
632 }) {
633 return
634 }
635 }
636
637 // Start new text block if not already active
638 if !isActiveText {
639 isActiveText = true
640 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
641 blockCounter++
642 if !yield(fantasy.StreamPart{
643 Type: fantasy.StreamPartTypeTextStart,
644 ID: currentTextBlockID,
645 }) {
646 return
647 }
648 }
649
650 if !yield(fantasy.StreamPart{
651 Type: fantasy.StreamPartTypeTextDelta,
652 ID: currentTextBlockID,
653 Delta: delta,
654 }) {
655 return
656 }
657 currentContent += delta
658 }
659 }
660 case part.FunctionCall != nil:
661 // End any active text or reasoning blocks
662 if isActiveText {
663 isActiveText = false
664 if !yield(fantasy.StreamPart{
665 Type: fantasy.StreamPartTypeTextEnd,
666 ID: currentTextBlockID,
667 }) {
668 return
669 }
670 }
671 if isActiveReasoning {
672 isActiveReasoning = false
673
674 metadata := &ReasoningMetadata{
675 Signature: string(part.ThoughtSignature),
676 }
677 if !yield(fantasy.StreamPart{
678 Type: fantasy.StreamPartTypeReasoningEnd,
679 ID: currentReasoningBlockID,
680 ProviderMetadata: fantasy.ProviderMetadata{
681 Name: metadata,
682 },
683 }) {
684 return
685 }
686 }
687
688 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
689
690 args, err := json.Marshal(part.FunctionCall.Args)
691 if err != nil {
692 yield(fantasy.StreamPart{
693 Type: fantasy.StreamPartTypeError,
694 Error: err,
695 })
696 return
697 }
698
699 if !yield(fantasy.StreamPart{
700 Type: fantasy.StreamPartTypeToolInputStart,
701 ID: toolCallID,
702 ToolCallName: part.FunctionCall.Name,
703 }) {
704 return
705 }
706
707 if !yield(fantasy.StreamPart{
708 Type: fantasy.StreamPartTypeToolInputDelta,
709 ID: toolCallID,
710 Delta: string(args),
711 }) {
712 return
713 }
714
715 if !yield(fantasy.StreamPart{
716 Type: fantasy.StreamPartTypeToolInputEnd,
717 ID: toolCallID,
718 }) {
719 return
720 }
721
722 if !yield(fantasy.StreamPart{
723 Type: fantasy.StreamPartTypeToolCall,
724 ID: toolCallID,
725 ToolCallName: part.FunctionCall.Name,
726 ToolCallInput: string(args),
727 ProviderExecuted: false,
728 }) {
729 return
730 }
731
732 toolCalls = append(toolCalls, fantasy.ToolCallContent{
733 ToolCallID: toolCallID,
734 ToolName: part.FunctionCall.Name,
735 Input: string(args),
736 ProviderExecuted: false,
737 })
738 }
739 }
740 }
741
742 if resp.UsageMetadata != nil {
743 currentUsage := mapUsage(resp.UsageMetadata)
744 // if first usage chunk
745 if usage == nil {
746 usage = ¤tUsage
747 } else {
748 usage.OutputTokens += currentUsage.OutputTokens
749 usage.ReasoningTokens += currentUsage.ReasoningTokens
750 usage.CacheReadTokens += currentUsage.CacheReadTokens
751 }
752 }
753
754 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
755 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
756 }
757 }
758
759 // Close any open blocks before finishing
760 if isActiveText {
761 if !yield(fantasy.StreamPart{
762 Type: fantasy.StreamPartTypeTextEnd,
763 ID: currentTextBlockID,
764 }) {
765 return
766 }
767 }
768 if isActiveReasoning {
769 if !yield(fantasy.StreamPart{
770 Type: fantasy.StreamPartTypeReasoningEnd,
771 ID: currentReasoningBlockID,
772 }) {
773 return
774 }
775 }
776
777 finishReason := lastFinishReason
778 if len(toolCalls) > 0 {
779 finishReason = fantasy.FinishReasonToolCalls
780 } else if finishReason == "" {
781 finishReason = fantasy.FinishReasonStop
782 }
783
784 yield(fantasy.StreamPart{
785 Type: fantasy.StreamPartTypeFinish,
786 Usage: *usage,
787 FinishReason: finishReason,
788 })
789 }, nil
790}
791
792func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) {
793 for _, tool := range tools {
794 if tool.GetType() == fantasy.ToolTypeFunction {
795 ft, ok := tool.(fantasy.FunctionTool)
796 if !ok {
797 continue
798 }
799
800 required := []string{}
801 var properties map[string]any
802 if props, ok := ft.InputSchema["properties"]; ok {
803 properties, _ = props.(map[string]any)
804 }
805 if req, ok := ft.InputSchema["required"]; ok {
806 if reqArr, ok := req.([]string); ok {
807 required = reqArr
808 }
809 }
810 declaration := &genai.FunctionDeclaration{
811 Name: ft.Name,
812 Description: ft.Description,
813 Parameters: &genai.Schema{
814 Type: genai.TypeObject,
815 Properties: convertSchemaProperties(properties),
816 Required: required,
817 },
818 }
819 googleTools = append(googleTools, declaration)
820 continue
821 }
822 // TODO: handle provider tool calls
823 warnings = append(warnings, fantasy.CallWarning{
824 Type: fantasy.CallWarningTypeUnsupportedTool,
825 Tool: tool,
826 Message: "tool is not supported",
827 })
828 }
829 if toolChoice == nil {
830 return googleTools, googleToolChoice, warnings
831 }
832 switch *toolChoice {
833 case fantasy.ToolChoiceAuto:
834 googleToolChoice = &genai.ToolConfig{
835 FunctionCallingConfig: &genai.FunctionCallingConfig{
836 Mode: genai.FunctionCallingConfigModeAuto,
837 },
838 }
839 case fantasy.ToolChoiceRequired:
840 googleToolChoice = &genai.ToolConfig{
841 FunctionCallingConfig: &genai.FunctionCallingConfig{
842 Mode: genai.FunctionCallingConfigModeAny,
843 },
844 }
845 case fantasy.ToolChoiceNone:
846 googleToolChoice = &genai.ToolConfig{
847 FunctionCallingConfig: &genai.FunctionCallingConfig{
848 Mode: genai.FunctionCallingConfigModeNone,
849 },
850 }
851 default:
852 googleToolChoice = &genai.ToolConfig{
853 FunctionCallingConfig: &genai.FunctionCallingConfig{
854 Mode: genai.FunctionCallingConfigModeAny,
855 AllowedFunctionNames: []string{
856 string(*toolChoice),
857 },
858 },
859 }
860 }
861 return googleTools, googleToolChoice, warnings
862}
863
864func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
865 properties := make(map[string]*genai.Schema)
866
867 for name, param := range parameters {
868 properties[name] = convertToSchema(param)
869 }
870
871 return properties
872}
873
874func convertToSchema(param any) *genai.Schema {
875 schema := &genai.Schema{Type: genai.TypeString}
876
877 paramMap, ok := param.(map[string]any)
878 if !ok {
879 return schema
880 }
881
882 if desc, ok := paramMap["description"].(string); ok {
883 schema.Description = desc
884 }
885
886 typeVal, hasType := paramMap["type"]
887 if !hasType {
888 return schema
889 }
890
891 typeStr, ok := typeVal.(string)
892 if !ok {
893 return schema
894 }
895
896 schema.Type = mapJSONTypeToGoogle(typeStr)
897
898 switch typeStr {
899 case "array":
900 schema.Items = processArrayItems(paramMap)
901 case "object":
902 if props, ok := paramMap["properties"].(map[string]any); ok {
903 schema.Properties = convertSchemaProperties(props)
904 }
905 }
906
907 return schema
908}
909
910func processArrayItems(paramMap map[string]any) *genai.Schema {
911 items, ok := paramMap["items"].(map[string]any)
912 if !ok {
913 return nil
914 }
915
916 return convertToSchema(items)
917}
918
919func mapJSONTypeToGoogle(jsonType string) genai.Type {
920 switch jsonType {
921 case "string":
922 return genai.TypeString
923 case "number":
924 return genai.TypeNumber
925 case "integer":
926 return genai.TypeInteger
927 case "boolean":
928 return genai.TypeBoolean
929 case "array":
930 return genai.TypeArray
931 case "object":
932 return genai.TypeObject
933 default:
934 return genai.TypeString // Default to string for unknown types
935 }
936}
937
938func (g languageModel) mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) {
939 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
940 return nil, errors.New("no response from model")
941 }
942
943 var (
944 content []fantasy.Content
945 finishReason fantasy.FinishReason
946 hasToolCalls bool
947 candidate = response.Candidates[0]
948 )
949
950 for _, part := range candidate.Content.Parts {
951 switch {
952 case part.Text != "":
953 if part.Thought {
954 metadata := &ReasoningMetadata{
955 Signature: string(part.ThoughtSignature),
956 }
957 content = append(content, fantasy.ReasoningContent{Text: part.Text, ProviderMetadata: fantasy.ProviderMetadata{Name: metadata}})
958 } else {
959 content = append(content, fantasy.TextContent{Text: part.Text})
960 }
961 case part.FunctionCall != nil:
962 input, err := json.Marshal(part.FunctionCall.Args)
963 if err != nil {
964 return nil, err
965 }
966 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
967 content = append(content, fantasy.ToolCallContent{
968 ToolCallID: toolCallID,
969 ToolName: part.FunctionCall.Name,
970 Input: string(input),
971 ProviderExecuted: false,
972 })
973 hasToolCalls = true
974 default:
975 // Silently skip unknown part types instead of erroring
976 // This allows for forward compatibility with new part types
977 }
978 }
979
980 if hasToolCalls {
981 finishReason = fantasy.FinishReasonToolCalls
982 } else {
983 finishReason = mapFinishReason(candidate.FinishReason)
984 }
985
986 return &fantasy.Response{
987 Content: content,
988 Usage: mapUsage(response.UsageMetadata),
989 FinishReason: finishReason,
990 Warnings: warnings,
991 }, nil
992}
993
994func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason {
995 switch reason {
996 case genai.FinishReasonStop:
997 return fantasy.FinishReasonStop
998 case genai.FinishReasonMaxTokens:
999 return fantasy.FinishReasonLength
1000 case genai.FinishReasonSafety,
1001 genai.FinishReasonBlocklist,
1002 genai.FinishReasonProhibitedContent,
1003 genai.FinishReasonSPII,
1004 genai.FinishReasonImageSafety:
1005 return fantasy.FinishReasonContentFilter
1006 case genai.FinishReasonRecitation,
1007 genai.FinishReasonLanguage,
1008 genai.FinishReasonMalformedFunctionCall:
1009 return fantasy.FinishReasonError
1010 case genai.FinishReasonOther:
1011 return fantasy.FinishReasonOther
1012 default:
1013 return fantasy.FinishReasonUnknown
1014 }
1015}
1016
1017func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage {
1018 return fantasy.Usage{
1019 InputTokens: int64(usage.PromptTokenCount),
1020 OutputTokens: int64(usage.CandidatesTokenCount),
1021 TotalTokens: int64(usage.TotalTokenCount),
1022 ReasoningTokens: int64(usage.ThoughtsTokenCount),
1023 CacheCreationTokens: 0,
1024 CacheReadTokens: int64(usage.CachedContentTokenCount),
1025 }
1026}