1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "net/http"
11 "strings"
12
13 "charm.land/fantasy"
14 "charm.land/fantasy/providers/anthropic"
15 "cloud.google.com/go/auth"
16 "github.com/charmbracelet/go-genai"
17 "github.com/charmbracelet/x/exp/slice"
18 "github.com/google/uuid"
19)
20
21// Name is the name of the Google provider.
22const Name = "google"
23
24type provider struct {
25 options options
26}
27
28// ToolCallIDFunc defines a function that generates a tool call ID.
29type ToolCallIDFunc = func() string
30
31type options struct {
32 apiKey string
33 name string
34 baseURL string
35 headers map[string]string
36 client *http.Client
37 backend genai.Backend
38 project string
39 location string
40 skipAuth bool
41 generateID ToolCallIDFunc
42}
43
44// Option defines a function that configures Google provider options.
45type Option = func(*options)
46
47// New creates a new Google provider with the given options.
48func New(opts ...Option) (fantasy.Provider, error) {
49 options := options{
50 headers: map[string]string{},
51 generateID: func() string {
52 return uuid.NewString()
53 },
54 }
55 for _, o := range opts {
56 o(&options)
57 }
58
59 options.name = cmp.Or(options.name, Name)
60
61 return &provider{
62 options: options,
63 }, nil
64}
65
66// WithBaseURL sets the base URL for the Google provider.
67func WithBaseURL(baseURL string) Option {
68 return func(o *options) {
69 o.baseURL = baseURL
70 }
71}
72
73// WithGeminiAPIKey sets the Gemini API key for the Google provider.
74func WithGeminiAPIKey(apiKey string) Option {
75 return func(o *options) {
76 o.backend = genai.BackendGeminiAPI
77 o.apiKey = apiKey
78 o.project = ""
79 o.location = ""
80 }
81}
82
83// WithVertex configures the Google provider to use Vertex AI.
84func WithVertex(project, location string) Option {
85 if project == "" || location == "" {
86 panic("project and location must be provided")
87 }
88 return func(o *options) {
89 o.backend = genai.BackendVertexAI
90 o.apiKey = ""
91 o.project = project
92 o.location = location
93 }
94}
95
96// WithSkipAuth configures whether to skip authentication for the Google provider.
97func WithSkipAuth(skipAuth bool) Option {
98 return func(o *options) {
99 o.skipAuth = skipAuth
100 }
101}
102
103// WithName sets the name for the Google provider.
104func WithName(name string) Option {
105 return func(o *options) {
106 o.name = name
107 }
108}
109
110// WithHeaders sets the headers for the Google provider.
111func WithHeaders(headers map[string]string) Option {
112 return func(o *options) {
113 maps.Copy(o.headers, headers)
114 }
115}
116
117// WithHTTPClient sets the HTTP client for the Google provider.
118func WithHTTPClient(client *http.Client) Option {
119 return func(o *options) {
120 o.client = client
121 }
122}
123
124// WithToolCallIDFunc sets the function that generates a tool call ID.
125// WithToolCallIDFunc sets the function that generates a tool call ID.
126func WithToolCallIDFunc(f ToolCallIDFunc) Option {
127 return func(o *options) {
128 o.generateID = f
129 }
130}
131
132func (*provider) Name() string {
133 return Name
134}
135
136type languageModel struct {
137 provider string
138 modelID string
139 client *genai.Client
140 providerOptions options
141}
142
143// LanguageModel implements fantasy.Provider.
144func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.LanguageModel, error) {
145 if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
146 p, err := anthropic.New(
147 anthropic.WithVertex(a.options.project, a.options.location),
148 anthropic.WithHTTPClient(a.options.client),
149 anthropic.WithSkipAuth(a.options.skipAuth),
150 )
151 if err != nil {
152 return nil, err
153 }
154 return p.LanguageModel(ctx, modelID)
155 }
156
157 cc := &genai.ClientConfig{
158 HTTPClient: a.options.client,
159 Backend: a.options.backend,
160 APIKey: a.options.apiKey,
161 Project: a.options.project,
162 Location: a.options.location,
163 }
164 if a.options.skipAuth {
165 cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
166 }
167
168 if a.options.baseURL != "" || len(a.options.headers) > 0 {
169 headers := http.Header{}
170 for k, v := range a.options.headers {
171 headers.Add(k, v)
172 }
173 cc.HTTPOptions = genai.HTTPOptions{
174 BaseURL: a.options.baseURL,
175 Headers: headers,
176 }
177 }
178 client, err := genai.NewClient(ctx, cc)
179 if err != nil {
180 return nil, err
181 }
182 return &languageModel{
183 modelID: modelID,
184 provider: a.options.name,
185 providerOptions: a.options,
186 client: client,
187 }, nil
188}
189
190func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) {
191 config := &genai.GenerateContentConfig{}
192
193 providerOptions := &ProviderOptions{}
194 if v, ok := call.ProviderOptions[Name]; ok {
195 providerOptions, ok = v.(*ProviderOptions)
196 if !ok {
197 return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
198 }
199 }
200
201 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
202
203 if providerOptions.ThinkingConfig != nil {
204 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
205 *providerOptions.ThinkingConfig.IncludeThoughts &&
206 strings.HasPrefix(g.provider, "google.vertex.") {
207 warnings = append(warnings, fantasy.CallWarning{
208 Type: fantasy.CallWarningTypeOther,
209 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
210 "and might not be supported or could behave unexpectedly with the current Google provider " +
211 fmt.Sprintf("(%s)", g.provider),
212 })
213 }
214
215 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
216 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
217 warnings = append(warnings, fantasy.CallWarning{
218 Type: fantasy.CallWarningTypeOther,
219 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
220 })
221 providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128))
222 }
223 }
224
225 isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")
226
227 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
228 if len(content) > 0 && content[0].Role == genai.RoleUser {
229 systemParts := []string{}
230 for _, sp := range systemInstructions.Parts {
231 systemParts = append(systemParts, sp.Text)
232 }
233 systemMsg := strings.Join(systemParts, "\n")
234 content[0].Parts = append([]*genai.Part{
235 {
236 Text: systemMsg + "\n\n",
237 },
238 }, content[0].Parts...)
239 systemInstructions = nil
240 }
241 }
242
243 config.SystemInstruction = systemInstructions
244
245 if call.MaxOutputTokens != nil {
246 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
247 }
248
249 if call.Temperature != nil {
250 tmp := float32(*call.Temperature)
251 config.Temperature = &tmp
252 }
253 if call.TopK != nil {
254 tmp := float32(*call.TopK)
255 config.TopK = &tmp
256 }
257 if call.TopP != nil {
258 tmp := float32(*call.TopP)
259 config.TopP = &tmp
260 }
261 if call.FrequencyPenalty != nil {
262 tmp := float32(*call.FrequencyPenalty)
263 config.FrequencyPenalty = &tmp
264 }
265 if call.PresencePenalty != nil {
266 tmp := float32(*call.PresencePenalty)
267 config.PresencePenalty = &tmp
268 }
269
270 if providerOptions.ThinkingConfig != nil {
271 config.ThinkingConfig = &genai.ThinkingConfig{}
272 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
273 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
274 }
275 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
276 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
277 config.ThinkingConfig.ThinkingBudget = &tmp
278 }
279 }
280 for _, safetySetting := range providerOptions.SafetySettings {
281 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
282 Category: genai.HarmCategory(safetySetting.Category),
283 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
284 })
285 }
286 if providerOptions.CachedContent != "" {
287 config.CachedContent = providerOptions.CachedContent
288 }
289
290 if len(call.Tools) > 0 {
291 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
292 config.ToolConfig = toolChoice
293 config.Tools = append(config.Tools, &genai.Tool{
294 FunctionDeclarations: tools,
295 })
296 warnings = append(warnings, toolWarnings...)
297 }
298
299 return config, content, warnings, nil
300}
301
302func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam
303 var systemInstructions *genai.Content
304 var content []*genai.Content
305 var warnings []fantasy.CallWarning
306
307 finishedSystemBlock := false
308 for _, msg := range prompt {
309 switch msg.Role {
310 case fantasy.MessageRoleSystem:
311 if finishedSystemBlock {
312 // skip multiple system messages that are separated by user/assistant messages
313 // TODO: see if we need to send error here?
314 continue
315 }
316 finishedSystemBlock = true
317
318 var systemMessages []string
319 for _, part := range msg.Content {
320 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
321 if !ok || text.Text == "" {
322 continue
323 }
324 systemMessages = append(systemMessages, text.Text)
325 }
326 if len(systemMessages) > 0 {
327 systemInstructions = &genai.Content{
328 Parts: []*genai.Part{
329 {
330 Text: strings.Join(systemMessages, "\n"),
331 },
332 },
333 }
334 }
335 case fantasy.MessageRoleUser:
336 var parts []*genai.Part
337 for _, part := range msg.Content {
338 switch part.GetType() {
339 case fantasy.ContentTypeText:
340 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
341 if !ok || text.Text == "" {
342 continue
343 }
344 parts = append(parts, &genai.Part{
345 Text: text.Text,
346 })
347 case fantasy.ContentTypeFile:
348 file, ok := fantasy.AsMessagePart[fantasy.FilePart](part)
349 if !ok {
350 continue
351 }
352 parts = append(parts, &genai.Part{
353 InlineData: &genai.Blob{
354 Data: file.Data,
355 MIMEType: file.MediaType,
356 },
357 })
358 }
359 }
360 if len(parts) > 0 {
361 content = append(content, &genai.Content{
362 Role: genai.RoleUser,
363 Parts: parts,
364 })
365 }
366 case fantasy.MessageRoleAssistant:
367 var parts []*genai.Part
368 // INFO: (kujtim) this is kind of a hacky way to include thinking for google
369 // weirdly thinking needs to be included in a function call
370 var signature []byte
371 for _, part := range msg.Content {
372 switch part.GetType() {
373 case fantasy.ContentTypeText:
374 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
375 if !ok || text.Text == "" {
376 continue
377 }
378 parts = append(parts, &genai.Part{
379 Text: text.Text,
380 })
381 case fantasy.ContentTypeToolCall:
382 toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
383 if !ok {
384 continue
385 }
386
387 var result map[string]any
388 err := json.Unmarshal([]byte(toolCall.Input), &result)
389 if err != nil {
390 continue
391 }
392 parts = append(parts, &genai.Part{
393 FunctionCall: &genai.FunctionCall{
394 ID: toolCall.ToolCallID,
395 Name: toolCall.ToolName,
396 Args: result,
397 },
398 ThoughtSignature: signature,
399 })
400 // reset
401 signature = nil
402 case fantasy.ContentTypeReasoning:
403 reasoning, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](part)
404 if !ok {
405 continue
406 }
407 metadata, ok := reasoning.ProviderOptions[Name]
408 if !ok {
409 continue
410 }
411 reasoningMetadata, ok := metadata.(*ReasoningMetadata)
412 if !ok {
413 continue
414 }
415 if !ok || reasoningMetadata.Signature == "" {
416 continue
417 }
418 signature = []byte(reasoningMetadata.Signature)
419 }
420 }
421 if len(parts) > 0 {
422 content = append(content, &genai.Content{
423 Role: genai.RoleModel,
424 Parts: parts,
425 })
426 }
427 case fantasy.MessageRoleTool:
428 var parts []*genai.Part
429 for _, part := range msg.Content {
430 switch part.GetType() {
431 case fantasy.ContentTypeToolResult:
432 result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
433 if !ok {
434 continue
435 }
436 var toolCall fantasy.ToolCallPart
437 for _, m := range prompt {
438 if m.Role == fantasy.MessageRoleAssistant {
439 for _, content := range m.Content {
440 tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content)
441 if !ok {
442 continue
443 }
444 if tc.ToolCallID == result.ToolCallID {
445 toolCall = tc
446 break
447 }
448 }
449 }
450 }
451 switch result.Output.GetType() {
452 case fantasy.ToolResultContentTypeText:
453 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output)
454 if !ok {
455 continue
456 }
457 response := map[string]any{"result": content.Text}
458 parts = append(parts, &genai.Part{
459 FunctionResponse: &genai.FunctionResponse{
460 ID: result.ToolCallID,
461 Response: response,
462 Name: toolCall.ToolName,
463 },
464 })
465
466 case fantasy.ToolResultContentTypeError:
467 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
468 if !ok {
469 continue
470 }
471 response := map[string]any{"result": content.Error.Error()}
472 parts = append(parts, &genai.Part{
473 FunctionResponse: &genai.FunctionResponse{
474 ID: result.ToolCallID,
475 Response: response,
476 Name: toolCall.ToolName,
477 },
478 })
479 }
480 }
481 }
482 if len(parts) > 0 {
483 content = append(content, &genai.Content{
484 Role: genai.RoleUser,
485 Parts: parts,
486 })
487 }
488 default:
489 panic("unsupported message role: " + msg.Role)
490 }
491 }
492 return systemInstructions, content, warnings
493}
494
495// Generate implements fantasy.LanguageModel.
496func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
497 config, contents, warnings, err := g.prepareParams(call)
498 if err != nil {
499 return nil, err
500 }
501
502 lastMessage, history, ok := slice.Pop(contents)
503 if !ok {
504 return nil, errors.New("no messages to send")
505 }
506
507 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
508 if err != nil {
509 return nil, err
510 }
511
512 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
513 if err != nil {
514 return nil, err
515 }
516
517 return g.mapResponse(response, warnings)
518}
519
520// Model implements fantasy.LanguageModel.
521func (g *languageModel) Model() string {
522 return g.modelID
523}
524
525// Provider implements fantasy.LanguageModel.
526func (g *languageModel) Provider() string {
527 return g.provider
528}
529
530// Stream implements fantasy.LanguageModel.
531func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
532 config, contents, warnings, err := g.prepareParams(call)
533 if err != nil {
534 return nil, err
535 }
536
537 lastMessage, history, ok := slice.Pop(contents)
538 if !ok {
539 return nil, errors.New("no messages to send")
540 }
541
542 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
543 if err != nil {
544 return nil, err
545 }
546
547 return func(yield func(fantasy.StreamPart) bool) {
548 if len(warnings) > 0 {
549 if !yield(fantasy.StreamPart{
550 Type: fantasy.StreamPartTypeWarnings,
551 Warnings: warnings,
552 }) {
553 return
554 }
555 }
556
557 var currentContent string
558 var toolCalls []fantasy.ToolCallContent
559 var isActiveText bool
560 var isActiveReasoning bool
561 var blockCounter int
562 var currentTextBlockID string
563 var currentReasoningBlockID string
564 var usage *fantasy.Usage
565 var lastFinishReason fantasy.FinishReason
566
567 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
568 if err != nil {
569 yield(fantasy.StreamPart{
570 Type: fantasy.StreamPartTypeError,
571 Error: err,
572 })
573 return
574 }
575
576 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
577 for _, part := range resp.Candidates[0].Content.Parts {
578 switch {
579 case part.Text != "":
580 delta := part.Text
581 if delta != "" {
582 // Check if this is a reasoning/thought part
583 if part.Thought {
584 // End any active text block before starting reasoning
585 if isActiveText {
586 isActiveText = false
587 if !yield(fantasy.StreamPart{
588 Type: fantasy.StreamPartTypeTextEnd,
589 ID: currentTextBlockID,
590 }) {
591 return
592 }
593 }
594
595 // Start new reasoning block if not already active
596 if !isActiveReasoning {
597 isActiveReasoning = true
598 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
599 blockCounter++
600 if !yield(fantasy.StreamPart{
601 Type: fantasy.StreamPartTypeReasoningStart,
602 ID: currentReasoningBlockID,
603 }) {
604 return
605 }
606 }
607
608 if !yield(fantasy.StreamPart{
609 Type: fantasy.StreamPartTypeReasoningDelta,
610 ID: currentReasoningBlockID,
611 Delta: delta,
612 }) {
613 return
614 }
615 } else {
616 // Regular text part
617 // End any active reasoning block before starting text
618 if isActiveReasoning {
619 isActiveReasoning = false
620 metadata := &ReasoningMetadata{
621 Signature: string(part.ThoughtSignature),
622 }
623 if !yield(fantasy.StreamPart{
624 Type: fantasy.StreamPartTypeReasoningEnd,
625 ID: currentReasoningBlockID,
626 ProviderMetadata: fantasy.ProviderMetadata{
627 Name: metadata,
628 },
629 }) {
630 return
631 }
632 }
633
634 // Start new text block if not already active
635 if !isActiveText {
636 isActiveText = true
637 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
638 blockCounter++
639 if !yield(fantasy.StreamPart{
640 Type: fantasy.StreamPartTypeTextStart,
641 ID: currentTextBlockID,
642 }) {
643 return
644 }
645 }
646
647 if !yield(fantasy.StreamPart{
648 Type: fantasy.StreamPartTypeTextDelta,
649 ID: currentTextBlockID,
650 Delta: delta,
651 }) {
652 return
653 }
654 currentContent += delta
655 }
656 }
657 case part.FunctionCall != nil:
658 // End any active text or reasoning blocks
659 if isActiveText {
660 isActiveText = false
661 if !yield(fantasy.StreamPart{
662 Type: fantasy.StreamPartTypeTextEnd,
663 ID: currentTextBlockID,
664 }) {
665 return
666 }
667 }
668 if isActiveReasoning {
669 isActiveReasoning = false
670
671 metadata := &ReasoningMetadata{
672 Signature: string(part.ThoughtSignature),
673 }
674 if !yield(fantasy.StreamPart{
675 Type: fantasy.StreamPartTypeReasoningEnd,
676 ID: currentReasoningBlockID,
677 ProviderMetadata: fantasy.ProviderMetadata{
678 Name: metadata,
679 },
680 }) {
681 return
682 }
683 }
684
685 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.generateID())
686
687 args, err := json.Marshal(part.FunctionCall.Args)
688 if err != nil {
689 yield(fantasy.StreamPart{
690 Type: fantasy.StreamPartTypeError,
691 Error: err,
692 })
693 return
694 }
695
696 if !yield(fantasy.StreamPart{
697 Type: fantasy.StreamPartTypeToolInputStart,
698 ID: toolCallID,
699 ToolCallName: part.FunctionCall.Name,
700 }) {
701 return
702 }
703
704 if !yield(fantasy.StreamPart{
705 Type: fantasy.StreamPartTypeToolInputDelta,
706 ID: toolCallID,
707 Delta: string(args),
708 }) {
709 return
710 }
711
712 if !yield(fantasy.StreamPart{
713 Type: fantasy.StreamPartTypeToolInputEnd,
714 ID: toolCallID,
715 }) {
716 return
717 }
718
719 if !yield(fantasy.StreamPart{
720 Type: fantasy.StreamPartTypeToolCall,
721 ID: toolCallID,
722 ToolCallName: part.FunctionCall.Name,
723 ToolCallInput: string(args),
724 ProviderExecuted: false,
725 }) {
726 return
727 }
728
729 toolCalls = append(toolCalls, fantasy.ToolCallContent{
730 ToolCallID: toolCallID,
731 ToolName: part.FunctionCall.Name,
732 Input: string(args),
733 ProviderExecuted: false,
734 })
735 }
736 }
737 }
738
739 if resp.UsageMetadata != nil {
740 currentUsage := mapUsage(resp.UsageMetadata)
741 // if first usage chunk
742 if usage == nil {
743 usage = ¤tUsage
744 } else {
745 usage.OutputTokens += currentUsage.OutputTokens
746 usage.ReasoningTokens += currentUsage.ReasoningTokens
747 usage.CacheReadTokens += currentUsage.CacheReadTokens
748 }
749 }
750
751 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
752 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
753 }
754 }
755
756 // Close any open blocks before finishing
757 if isActiveText {
758 if !yield(fantasy.StreamPart{
759 Type: fantasy.StreamPartTypeTextEnd,
760 ID: currentTextBlockID,
761 }) {
762 return
763 }
764 }
765 if isActiveReasoning {
766 if !yield(fantasy.StreamPart{
767 Type: fantasy.StreamPartTypeReasoningEnd,
768 ID: currentReasoningBlockID,
769 }) {
770 return
771 }
772 }
773
774 finishReason := lastFinishReason
775 if len(toolCalls) > 0 {
776 finishReason = fantasy.FinishReasonToolCalls
777 } else if finishReason == "" {
778 finishReason = fantasy.FinishReasonStop
779 }
780
781 yield(fantasy.StreamPart{
782 Type: fantasy.StreamPartTypeFinish,
783 Usage: *usage,
784 FinishReason: finishReason,
785 })
786 }, nil
787}
788
789func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) {
790 for _, tool := range tools {
791 if tool.GetType() == fantasy.ToolTypeFunction {
792 ft, ok := tool.(fantasy.FunctionTool)
793 if !ok {
794 continue
795 }
796
797 required := []string{}
798 var properties map[string]any
799 if props, ok := ft.InputSchema["properties"]; ok {
800 properties, _ = props.(map[string]any)
801 }
802 if req, ok := ft.InputSchema["required"]; ok {
803 if reqArr, ok := req.([]string); ok {
804 required = reqArr
805 }
806 }
807 declaration := &genai.FunctionDeclaration{
808 Name: ft.Name,
809 Description: ft.Description,
810 Parameters: &genai.Schema{
811 Type: genai.TypeObject,
812 Properties: convertSchemaProperties(properties),
813 Required: required,
814 },
815 }
816 googleTools = append(googleTools, declaration)
817 continue
818 }
819 // TODO: handle provider tool calls
820 warnings = append(warnings, fantasy.CallWarning{
821 Type: fantasy.CallWarningTypeUnsupportedTool,
822 Tool: tool,
823 Message: "tool is not supported",
824 })
825 }
826 if toolChoice == nil {
827 return googleTools, googleToolChoice, warnings
828 }
829 switch *toolChoice {
830 case fantasy.ToolChoiceAuto:
831 googleToolChoice = &genai.ToolConfig{
832 FunctionCallingConfig: &genai.FunctionCallingConfig{
833 Mode: genai.FunctionCallingConfigModeAuto,
834 },
835 }
836 case fantasy.ToolChoiceRequired:
837 googleToolChoice = &genai.ToolConfig{
838 FunctionCallingConfig: &genai.FunctionCallingConfig{
839 Mode: genai.FunctionCallingConfigModeAny,
840 },
841 }
842 case fantasy.ToolChoiceNone:
843 googleToolChoice = &genai.ToolConfig{
844 FunctionCallingConfig: &genai.FunctionCallingConfig{
845 Mode: genai.FunctionCallingConfigModeNone,
846 },
847 }
848 default:
849 googleToolChoice = &genai.ToolConfig{
850 FunctionCallingConfig: &genai.FunctionCallingConfig{
851 Mode: genai.FunctionCallingConfigModeAny,
852 AllowedFunctionNames: []string{
853 string(*toolChoice),
854 },
855 },
856 }
857 }
858 return googleTools, googleToolChoice, warnings
859}
860
861func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
862 properties := make(map[string]*genai.Schema)
863
864 for name, param := range parameters {
865 properties[name] = convertToSchema(param)
866 }
867
868 return properties
869}
870
871func convertToSchema(param any) *genai.Schema {
872 schema := &genai.Schema{Type: genai.TypeString}
873
874 paramMap, ok := param.(map[string]any)
875 if !ok {
876 return schema
877 }
878
879 if desc, ok := paramMap["description"].(string); ok {
880 schema.Description = desc
881 }
882
883 typeVal, hasType := paramMap["type"]
884 if !hasType {
885 return schema
886 }
887
888 typeStr, ok := typeVal.(string)
889 if !ok {
890 return schema
891 }
892
893 schema.Type = mapJSONTypeToGoogle(typeStr)
894
895 switch typeStr {
896 case "array":
897 schema.Items = processArrayItems(paramMap)
898 case "object":
899 if props, ok := paramMap["properties"].(map[string]any); ok {
900 schema.Properties = convertSchemaProperties(props)
901 }
902 }
903
904 return schema
905}
906
907func processArrayItems(paramMap map[string]any) *genai.Schema {
908 items, ok := paramMap["items"].(map[string]any)
909 if !ok {
910 return nil
911 }
912
913 return convertToSchema(items)
914}
915
916func mapJSONTypeToGoogle(jsonType string) genai.Type {
917 switch jsonType {
918 case "string":
919 return genai.TypeString
920 case "number":
921 return genai.TypeNumber
922 case "integer":
923 return genai.TypeInteger
924 case "boolean":
925 return genai.TypeBoolean
926 case "array":
927 return genai.TypeArray
928 case "object":
929 return genai.TypeObject
930 default:
931 return genai.TypeString // Default to string for unknown types
932 }
933}
934
935func (g languageModel) mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) {
936 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
937 return nil, errors.New("no response from model")
938 }
939
940 var (
941 content []fantasy.Content
942 finishReason fantasy.FinishReason
943 hasToolCalls bool
944 candidate = response.Candidates[0]
945 )
946
947 for _, part := range candidate.Content.Parts {
948 switch {
949 case part.Text != "":
950 if part.Thought {
951 metadata := &ReasoningMetadata{
952 Signature: string(part.ThoughtSignature),
953 }
954 content = append(content, fantasy.ReasoningContent{Text: part.Text, ProviderMetadata: fantasy.ProviderMetadata{Name: metadata}})
955 } else {
956 content = append(content, fantasy.TextContent{Text: part.Text})
957 }
958 case part.FunctionCall != nil:
959 input, err := json.Marshal(part.FunctionCall.Args)
960 if err != nil {
961 return nil, err
962 }
963 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.generateID())
964 content = append(content, fantasy.ToolCallContent{
965 ToolCallID: toolCallID,
966 ToolName: part.FunctionCall.Name,
967 Input: string(input),
968 ProviderExecuted: false,
969 })
970 hasToolCalls = true
971 default:
972 // Silently skip unknown part types instead of erroring
973 // This allows for forward compatibility with new part types
974 }
975 }
976
977 if hasToolCalls {
978 finishReason = fantasy.FinishReasonToolCalls
979 } else {
980 finishReason = mapFinishReason(candidate.FinishReason)
981 }
982
983 return &fantasy.Response{
984 Content: content,
985 Usage: mapUsage(response.UsageMetadata),
986 FinishReason: finishReason,
987 Warnings: warnings,
988 }, nil
989}
990
991func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason {
992 switch reason {
993 case genai.FinishReasonStop:
994 return fantasy.FinishReasonStop
995 case genai.FinishReasonMaxTokens:
996 return fantasy.FinishReasonLength
997 case genai.FinishReasonSafety,
998 genai.FinishReasonBlocklist,
999 genai.FinishReasonProhibitedContent,
1000 genai.FinishReasonSPII,
1001 genai.FinishReasonImageSafety:
1002 return fantasy.FinishReasonContentFilter
1003 case genai.FinishReasonRecitation,
1004 genai.FinishReasonLanguage,
1005 genai.FinishReasonMalformedFunctionCall:
1006 return fantasy.FinishReasonError
1007 case genai.FinishReasonOther:
1008 return fantasy.FinishReasonOther
1009 default:
1010 return fantasy.FinishReasonUnknown
1011 }
1012}
1013
1014func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage {
1015 return fantasy.Usage{
1016 InputTokens: int64(usage.PromptTokenCount),
1017 OutputTokens: int64(usage.CandidatesTokenCount),
1018 TotalTokens: int64(usage.TotalTokenCount),
1019 ReasoningTokens: int64(usage.ThoughtsTokenCount),
1020 CacheCreationTokens: 0,
1021 CacheReadTokens: int64(usage.CachedContentTokenCount),
1022 }
1023}