1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "maps"
11 "net/http"
12 "strings"
13
14 "charm.land/fantasy"
15 "charm.land/fantasy/providers/anthropic"
16 "cloud.google.com/go/auth"
17 "github.com/charmbracelet/x/exp/slice"
18 "github.com/google/uuid"
19 "google.golang.org/genai"
20)
21
22// Name is the name of the Google provider.
23const Name = "google"
24
25type provider struct {
26 options options
27}
28
29type options struct {
30 apiKey string
31 name string
32 baseURL string
33 headers map[string]string
34 client *http.Client
35 backend genai.Backend
36 project string
37 location string
38 skipAuth bool
39}
40
41// Option defines a function that configures Google provider options.
42type Option = func(*options)
43
44// New creates a new Google provider with the given options.
45func New(opts ...Option) fantasy.Provider {
46 options := options{
47 headers: map[string]string{},
48 }
49 for _, o := range opts {
50 o(&options)
51 }
52
53 options.name = cmp.Or(options.name, Name)
54
55 return &provider{
56 options: options,
57 }
58}
59
60// WithBaseURL sets the base URL for the Google provider.
61func WithBaseURL(baseURL string) Option {
62 return func(o *options) {
63 o.baseURL = baseURL
64 }
65}
66
67// WithGeminiAPIKey sets the Gemini API key for the Google provider.
68func WithGeminiAPIKey(apiKey string) Option {
69 return func(o *options) {
70 o.backend = genai.BackendGeminiAPI
71 o.apiKey = apiKey
72 o.project = ""
73 o.location = ""
74 }
75}
76
77// WithVertex configures the Google provider to use Vertex AI.
78func WithVertex(project, location string) Option {
79 if project == "" || location == "" {
80 panic("project and location must be provided")
81 }
82 return func(o *options) {
83 o.backend = genai.BackendVertexAI
84 o.apiKey = ""
85 o.project = project
86 o.location = location
87 }
88}
89
90// WithSkipAuth configures whether to skip authentication for the Google provider.
91func WithSkipAuth(skipAuth bool) Option {
92 return func(o *options) {
93 o.skipAuth = skipAuth
94 }
95}
96
97// WithName sets the name for the Google provider.
98func WithName(name string) Option {
99 return func(o *options) {
100 o.name = name
101 }
102}
103
104// WithHeaders sets the headers for the Google provider.
105func WithHeaders(headers map[string]string) Option {
106 return func(o *options) {
107 maps.Copy(o.headers, headers)
108 }
109}
110
111// WithHTTPClient sets the HTTP client for the Google provider.
112func WithHTTPClient(client *http.Client) Option {
113 return func(o *options) {
114 o.client = client
115 }
116}
117
118func (*provider) Name() string {
119 return Name
120}
121
122type languageModel struct {
123 provider string
124 modelID string
125 client *genai.Client
126 providerOptions options
127}
128
129// LanguageModel implements fantasy.Provider.
130func (a *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error) {
131 if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
132 return anthropic.New(
133 anthropic.WithVertex(a.options.project, a.options.location),
134 anthropic.WithHTTPClient(a.options.client),
135 anthropic.WithSkipAuth(a.options.skipAuth),
136 ).LanguageModel(modelID)
137 }
138
139 cc := &genai.ClientConfig{
140 HTTPClient: a.options.client,
141 Backend: a.options.backend,
142 APIKey: a.options.apiKey,
143 Project: a.options.project,
144 Location: a.options.location,
145 }
146 if a.options.skipAuth {
147 cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
148 }
149
150 if a.options.baseURL != "" || len(a.options.headers) > 0 {
151 headers := http.Header{}
152 for k, v := range a.options.headers {
153 headers.Add(k, v)
154 }
155 cc.HTTPOptions = genai.HTTPOptions{
156 BaseURL: a.options.baseURL,
157 Headers: headers,
158 }
159 }
160 client, err := genai.NewClient(context.Background(), cc)
161 if err != nil {
162 return nil, err
163 }
164 return &languageModel{
165 modelID: modelID,
166 provider: a.options.name,
167 providerOptions: a.options,
168 client: client,
169 }, nil
170}
171
172func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) {
173 config := &genai.GenerateContentConfig{}
174
175 providerOptions := &ProviderOptions{}
176 if v, ok := call.ProviderOptions[Name]; ok {
177 providerOptions, ok = v.(*ProviderOptions)
178 if !ok {
179 return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
180 }
181 }
182
183 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
184
185 if providerOptions.ThinkingConfig != nil {
186 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
187 *providerOptions.ThinkingConfig.IncludeThoughts &&
188 strings.HasPrefix(g.provider, "google.vertex.") {
189 warnings = append(warnings, fantasy.CallWarning{
190 Type: fantasy.CallWarningTypeOther,
191 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
192 "and might not be supported or could behave unexpectedly with the current Google provider " +
193 fmt.Sprintf("(%s)", g.provider),
194 })
195 }
196
197 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
198 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
199 warnings = append(warnings, fantasy.CallWarning{
200 Type: fantasy.CallWarningTypeOther,
201 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
202 })
203 providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128))
204 }
205 }
206
207 isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")
208
209 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
210 if len(content) > 0 && content[0].Role == genai.RoleUser {
211 systemParts := []string{}
212 for _, sp := range systemInstructions.Parts {
213 systemParts = append(systemParts, sp.Text)
214 }
215 systemMsg := strings.Join(systemParts, "\n")
216 content[0].Parts = append([]*genai.Part{
217 {
218 Text: systemMsg + "\n\n",
219 },
220 }, content[0].Parts...)
221 systemInstructions = nil
222 }
223 }
224
225 config.SystemInstruction = systemInstructions
226
227 if call.MaxOutputTokens != nil {
228 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
229 }
230
231 if call.Temperature != nil {
232 tmp := float32(*call.Temperature)
233 config.Temperature = &tmp
234 }
235 if call.TopK != nil {
236 tmp := float32(*call.TopK)
237 config.TopK = &tmp
238 }
239 if call.TopP != nil {
240 tmp := float32(*call.TopP)
241 config.TopP = &tmp
242 }
243 if call.FrequencyPenalty != nil {
244 tmp := float32(*call.FrequencyPenalty)
245 config.FrequencyPenalty = &tmp
246 }
247 if call.PresencePenalty != nil {
248 tmp := float32(*call.PresencePenalty)
249 config.PresencePenalty = &tmp
250 }
251
252 if providerOptions.ThinkingConfig != nil {
253 config.ThinkingConfig = &genai.ThinkingConfig{}
254 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
255 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
256 }
257 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
258 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
259 config.ThinkingConfig.ThinkingBudget = &tmp
260 }
261 }
262 for _, safetySetting := range providerOptions.SafetySettings {
263 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
264 Category: genai.HarmCategory(safetySetting.Category),
265 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
266 })
267 }
268 if providerOptions.CachedContent != "" {
269 config.CachedContent = providerOptions.CachedContent
270 }
271
272 if len(call.Tools) > 0 {
273 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
274 config.ToolConfig = toolChoice
275 config.Tools = append(config.Tools, &genai.Tool{
276 FunctionDeclarations: tools,
277 })
278 warnings = append(warnings, toolWarnings...)
279 }
280
281 return config, content, warnings, nil
282}
283
284func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam
285 var systemInstructions *genai.Content
286 var content []*genai.Content
287 var warnings []fantasy.CallWarning
288
289 finishedSystemBlock := false
290 for _, msg := range prompt {
291 switch msg.Role {
292 case fantasy.MessageRoleSystem:
293 if finishedSystemBlock {
294 // skip multiple system messages that are separated by user/assistant messages
295 // TODO: see if we need to send error here?
296 continue
297 }
298 finishedSystemBlock = true
299
300 var systemMessages []string
301 for _, part := range msg.Content {
302 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
303 if !ok || text.Text == "" {
304 continue
305 }
306 systemMessages = append(systemMessages, text.Text)
307 }
308 if len(systemMessages) > 0 {
309 systemInstructions = &genai.Content{
310 Parts: []*genai.Part{
311 {
312 Text: strings.Join(systemMessages, "\n"),
313 },
314 },
315 }
316 }
317 case fantasy.MessageRoleUser:
318 var parts []*genai.Part
319 for _, part := range msg.Content {
320 switch part.GetType() {
321 case fantasy.ContentTypeText:
322 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
323 if !ok || text.Text == "" {
324 continue
325 }
326 parts = append(parts, &genai.Part{
327 Text: text.Text,
328 })
329 case fantasy.ContentTypeFile:
330 file, ok := fantasy.AsMessagePart[fantasy.FilePart](part)
331 if !ok {
332 continue
333 }
334 var encoded []byte
335 base64.StdEncoding.Encode(encoded, file.Data)
336 parts = append(parts, &genai.Part{
337 InlineData: &genai.Blob{
338 Data: encoded,
339 MIMEType: file.MediaType,
340 },
341 })
342 }
343 }
344 if len(parts) > 0 {
345 content = append(content, &genai.Content{
346 Role: genai.RoleUser,
347 Parts: parts,
348 })
349 }
350 case fantasy.MessageRoleAssistant:
351 var parts []*genai.Part
352 for _, part := range msg.Content {
353 switch part.GetType() {
354 case fantasy.ContentTypeText:
355 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
356 if !ok || text.Text == "" {
357 continue
358 }
359 parts = append(parts, &genai.Part{
360 Text: text.Text,
361 })
362 case fantasy.ContentTypeToolCall:
363 toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
364 if !ok {
365 continue
366 }
367
368 var result map[string]any
369 err := json.Unmarshal([]byte(toolCall.Input), &result)
370 if err != nil {
371 continue
372 }
373 parts = append(parts, &genai.Part{
374 FunctionCall: &genai.FunctionCall{
375 ID: toolCall.ToolCallID,
376 Name: toolCall.ToolName,
377 Args: result,
378 },
379 })
380 }
381 }
382 if len(parts) > 0 {
383 content = append(content, &genai.Content{
384 Role: genai.RoleModel,
385 Parts: parts,
386 })
387 }
388 case fantasy.MessageRoleTool:
389 var parts []*genai.Part
390 for _, part := range msg.Content {
391 switch part.GetType() {
392 case fantasy.ContentTypeToolResult:
393 result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
394 if !ok {
395 continue
396 }
397 var toolCall fantasy.ToolCallPart
398 for _, m := range prompt {
399 if m.Role == fantasy.MessageRoleAssistant {
400 for _, content := range m.Content {
401 tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content)
402 if !ok {
403 continue
404 }
405 if tc.ToolCallID == result.ToolCallID {
406 toolCall = tc
407 break
408 }
409 }
410 }
411 }
412 switch result.Output.GetType() {
413 case fantasy.ToolResultContentTypeText:
414 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output)
415 if !ok {
416 continue
417 }
418 response := map[string]any{"result": content.Text}
419 parts = append(parts, &genai.Part{
420 FunctionResponse: &genai.FunctionResponse{
421 ID: result.ToolCallID,
422 Response: response,
423 Name: toolCall.ToolName,
424 },
425 })
426
427 case fantasy.ToolResultContentTypeError:
428 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
429 if !ok {
430 continue
431 }
432 response := map[string]any{"result": content.Error.Error()}
433 parts = append(parts, &genai.Part{
434 FunctionResponse: &genai.FunctionResponse{
435 ID: result.ToolCallID,
436 Response: response,
437 Name: toolCall.ToolName,
438 },
439 })
440 }
441 }
442 }
443 if len(parts) > 0 {
444 content = append(content, &genai.Content{
445 Role: genai.RoleUser,
446 Parts: parts,
447 })
448 }
449 default:
450 panic("unsupported message role: " + msg.Role)
451 }
452 }
453 return systemInstructions, content, warnings
454}
455
456// Generate implements fantasy.LanguageModel.
457func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
458 config, contents, warnings, err := g.prepareParams(call)
459 if err != nil {
460 return nil, err
461 }
462
463 lastMessage, history, ok := slice.Pop(contents)
464 if !ok {
465 return nil, errors.New("no messages to send")
466 }
467
468 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
469 if err != nil {
470 return nil, err
471 }
472
473 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
474 if err != nil {
475 return nil, err
476 }
477
478 return mapResponse(response, warnings)
479}
480
481// Model implements fantasy.LanguageModel.
482func (g *languageModel) Model() string {
483 return g.modelID
484}
485
486// Provider implements fantasy.LanguageModel.
487func (g *languageModel) Provider() string {
488 return g.provider
489}
490
491// Stream implements fantasy.LanguageModel.
492func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
493 config, contents, warnings, err := g.prepareParams(call)
494 if err != nil {
495 return nil, err
496 }
497
498 lastMessage, history, ok := slice.Pop(contents)
499 if !ok {
500 return nil, errors.New("no messages to send")
501 }
502
503 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
504 if err != nil {
505 return nil, err
506 }
507
508 return func(yield func(fantasy.StreamPart) bool) {
509 if len(warnings) > 0 {
510 if !yield(fantasy.StreamPart{
511 Type: fantasy.StreamPartTypeWarnings,
512 Warnings: warnings,
513 }) {
514 return
515 }
516 }
517
518 var currentContent string
519 var toolCalls []fantasy.ToolCallContent
520 var isActiveText bool
521 var isActiveReasoning bool
522 var blockCounter int
523 var currentTextBlockID string
524 var currentReasoningBlockID string
525 var usage fantasy.Usage
526 var lastFinishReason fantasy.FinishReason
527
528 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
529 if err != nil {
530 yield(fantasy.StreamPart{
531 Type: fantasy.StreamPartTypeError,
532 Error: err,
533 })
534 return
535 }
536
537 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
538 for _, part := range resp.Candidates[0].Content.Parts {
539 switch {
540 case part.Text != "":
541 delta := part.Text
542 if delta != "" {
543 // Check if this is a reasoning/thought part
544 if part.Thought {
545 // End any active text block before starting reasoning
546 if isActiveText {
547 isActiveText = false
548 if !yield(fantasy.StreamPart{
549 Type: fantasy.StreamPartTypeTextEnd,
550 ID: currentTextBlockID,
551 }) {
552 return
553 }
554 }
555
556 // Start new reasoning block if not already active
557 if !isActiveReasoning {
558 isActiveReasoning = true
559 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
560 blockCounter++
561 if !yield(fantasy.StreamPart{
562 Type: fantasy.StreamPartTypeReasoningStart,
563 ID: currentReasoningBlockID,
564 }) {
565 return
566 }
567 }
568
569 if !yield(fantasy.StreamPart{
570 Type: fantasy.StreamPartTypeReasoningDelta,
571 ID: currentReasoningBlockID,
572 Delta: delta,
573 }) {
574 return
575 }
576 } else {
577 // Regular text part
578 // End any active reasoning block before starting text
579 if isActiveReasoning {
580 isActiveReasoning = false
581 if !yield(fantasy.StreamPart{
582 Type: fantasy.StreamPartTypeReasoningEnd,
583 ID: currentReasoningBlockID,
584 }) {
585 return
586 }
587 }
588
589 // Start new text block if not already active
590 if !isActiveText {
591 isActiveText = true
592 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
593 blockCounter++
594 if !yield(fantasy.StreamPart{
595 Type: fantasy.StreamPartTypeTextStart,
596 ID: currentTextBlockID,
597 }) {
598 return
599 }
600 }
601
602 if !yield(fantasy.StreamPart{
603 Type: fantasy.StreamPartTypeTextDelta,
604 ID: currentTextBlockID,
605 Delta: delta,
606 }) {
607 return
608 }
609 currentContent += delta
610 }
611 }
612 case part.FunctionCall != nil:
613 // End any active text or reasoning blocks
614 if isActiveText {
615 isActiveText = false
616 if !yield(fantasy.StreamPart{
617 Type: fantasy.StreamPartTypeTextEnd,
618 ID: currentTextBlockID,
619 }) {
620 return
621 }
622 }
623 if isActiveReasoning {
624 isActiveReasoning = false
625 if !yield(fantasy.StreamPart{
626 Type: fantasy.StreamPartTypeReasoningEnd,
627 ID: currentReasoningBlockID,
628 }) {
629 return
630 }
631 }
632
633 toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
634
635 args, err := json.Marshal(part.FunctionCall.Args)
636 if err != nil {
637 yield(fantasy.StreamPart{
638 Type: fantasy.StreamPartTypeError,
639 Error: err,
640 })
641 return
642 }
643
644 if !yield(fantasy.StreamPart{
645 Type: fantasy.StreamPartTypeToolInputStart,
646 ID: toolCallID,
647 ToolCallName: part.FunctionCall.Name,
648 }) {
649 return
650 }
651
652 if !yield(fantasy.StreamPart{
653 Type: fantasy.StreamPartTypeToolInputDelta,
654 ID: toolCallID,
655 Delta: string(args),
656 }) {
657 return
658 }
659
660 if !yield(fantasy.StreamPart{
661 Type: fantasy.StreamPartTypeToolInputEnd,
662 ID: toolCallID,
663 }) {
664 return
665 }
666
667 if !yield(fantasy.StreamPart{
668 Type: fantasy.StreamPartTypeToolCall,
669 ID: toolCallID,
670 ToolCallName: part.FunctionCall.Name,
671 ToolCallInput: string(args),
672 ProviderExecuted: false,
673 }) {
674 return
675 }
676
677 toolCalls = append(toolCalls, fantasy.ToolCallContent{
678 ToolCallID: toolCallID,
679 ToolName: part.FunctionCall.Name,
680 Input: string(args),
681 ProviderExecuted: false,
682 })
683 }
684 }
685 }
686
687 if resp.UsageMetadata != nil {
688 usage = mapUsage(resp.UsageMetadata)
689 }
690
691 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
692 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
693 }
694 }
695
696 // Close any open blocks before finishing
697 if isActiveText {
698 if !yield(fantasy.StreamPart{
699 Type: fantasy.StreamPartTypeTextEnd,
700 ID: currentTextBlockID,
701 }) {
702 return
703 }
704 }
705 if isActiveReasoning {
706 if !yield(fantasy.StreamPart{
707 Type: fantasy.StreamPartTypeReasoningEnd,
708 ID: currentReasoningBlockID,
709 }) {
710 return
711 }
712 }
713
714 finishReason := lastFinishReason
715 if len(toolCalls) > 0 {
716 finishReason = fantasy.FinishReasonToolCalls
717 } else if finishReason == "" {
718 finishReason = fantasy.FinishReasonStop
719 }
720
721 yield(fantasy.StreamPart{
722 Type: fantasy.StreamPartTypeFinish,
723 Usage: usage,
724 FinishReason: finishReason,
725 })
726 }, nil
727}
728
729func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) {
730 for _, tool := range tools {
731 if tool.GetType() == fantasy.ToolTypeFunction {
732 ft, ok := tool.(fantasy.FunctionTool)
733 if !ok {
734 continue
735 }
736
737 required := []string{}
738 var properties map[string]any
739 if props, ok := ft.InputSchema["properties"]; ok {
740 properties, _ = props.(map[string]any)
741 }
742 if req, ok := ft.InputSchema["required"]; ok {
743 if reqArr, ok := req.([]string); ok {
744 required = reqArr
745 }
746 }
747 declaration := &genai.FunctionDeclaration{
748 Name: ft.Name,
749 Description: ft.Description,
750 Parameters: &genai.Schema{
751 Type: genai.TypeObject,
752 Properties: convertSchemaProperties(properties),
753 Required: required,
754 },
755 }
756 googleTools = append(googleTools, declaration)
757 continue
758 }
759 // TODO: handle provider tool calls
760 warnings = append(warnings, fantasy.CallWarning{
761 Type: fantasy.CallWarningTypeUnsupportedTool,
762 Tool: tool,
763 Message: "tool is not supported",
764 })
765 }
766 if toolChoice == nil {
767 return googleTools, googleToolChoice, warnings
768 }
769 switch *toolChoice {
770 case fantasy.ToolChoiceAuto:
771 googleToolChoice = &genai.ToolConfig{
772 FunctionCallingConfig: &genai.FunctionCallingConfig{
773 Mode: genai.FunctionCallingConfigModeAuto,
774 },
775 }
776 case fantasy.ToolChoiceRequired:
777 googleToolChoice = &genai.ToolConfig{
778 FunctionCallingConfig: &genai.FunctionCallingConfig{
779 Mode: genai.FunctionCallingConfigModeAny,
780 },
781 }
782 case fantasy.ToolChoiceNone:
783 googleToolChoice = &genai.ToolConfig{
784 FunctionCallingConfig: &genai.FunctionCallingConfig{
785 Mode: genai.FunctionCallingConfigModeNone,
786 },
787 }
788 default:
789 googleToolChoice = &genai.ToolConfig{
790 FunctionCallingConfig: &genai.FunctionCallingConfig{
791 Mode: genai.FunctionCallingConfigModeAny,
792 AllowedFunctionNames: []string{
793 string(*toolChoice),
794 },
795 },
796 }
797 }
798 return googleTools, googleToolChoice, warnings
799}
800
801func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
802 properties := make(map[string]*genai.Schema)
803
804 for name, param := range parameters {
805 properties[name] = convertToSchema(param)
806 }
807
808 return properties
809}
810
811func convertToSchema(param any) *genai.Schema {
812 schema := &genai.Schema{Type: genai.TypeString}
813
814 paramMap, ok := param.(map[string]any)
815 if !ok {
816 return schema
817 }
818
819 if desc, ok := paramMap["description"].(string); ok {
820 schema.Description = desc
821 }
822
823 typeVal, hasType := paramMap["type"]
824 if !hasType {
825 return schema
826 }
827
828 typeStr, ok := typeVal.(string)
829 if !ok {
830 return schema
831 }
832
833 schema.Type = mapJSONTypeToGoogle(typeStr)
834
835 switch typeStr {
836 case "array":
837 schema.Items = processArrayItems(paramMap)
838 case "object":
839 if props, ok := paramMap["properties"].(map[string]any); ok {
840 schema.Properties = convertSchemaProperties(props)
841 }
842 }
843
844 return schema
845}
846
847func processArrayItems(paramMap map[string]any) *genai.Schema {
848 items, ok := paramMap["items"].(map[string]any)
849 if !ok {
850 return nil
851 }
852
853 return convertToSchema(items)
854}
855
856func mapJSONTypeToGoogle(jsonType string) genai.Type {
857 switch jsonType {
858 case "string":
859 return genai.TypeString
860 case "number":
861 return genai.TypeNumber
862 case "integer":
863 return genai.TypeInteger
864 case "boolean":
865 return genai.TypeBoolean
866 case "array":
867 return genai.TypeArray
868 case "object":
869 return genai.TypeObject
870 default:
871 return genai.TypeString // Default to string for unknown types
872 }
873}
874
875func mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) {
876 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
877 return nil, errors.New("no response from model")
878 }
879
880 var (
881 content []fantasy.Content
882 finishReason fantasy.FinishReason
883 hasToolCalls bool
884 candidate = response.Candidates[0]
885 )
886
887 for _, part := range candidate.Content.Parts {
888 switch {
889 case part.Text != "":
890 if part.Thought {
891 content = append(content, fantasy.ReasoningContent{Text: part.Text})
892 } else {
893 content = append(content, fantasy.TextContent{Text: part.Text})
894 }
895 case part.FunctionCall != nil:
896 input, err := json.Marshal(part.FunctionCall.Args)
897 if err != nil {
898 return nil, err
899 }
900 toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
901 content = append(content, fantasy.ToolCallContent{
902 ToolCallID: toolCallID,
903 ToolName: part.FunctionCall.Name,
904 Input: string(input),
905 ProviderExecuted: false,
906 })
907 hasToolCalls = true
908 default:
909 // Silently skip unknown part types instead of erroring
910 // This allows for forward compatibility with new part types
911 }
912 }
913
914 if hasToolCalls {
915 finishReason = fantasy.FinishReasonToolCalls
916 } else {
917 finishReason = mapFinishReason(candidate.FinishReason)
918 }
919
920 return &fantasy.Response{
921 Content: content,
922 Usage: mapUsage(response.UsageMetadata),
923 FinishReason: finishReason,
924 Warnings: warnings,
925 }, nil
926}
927
928func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason {
929 switch reason {
930 case genai.FinishReasonStop:
931 return fantasy.FinishReasonStop
932 case genai.FinishReasonMaxTokens:
933 return fantasy.FinishReasonLength
934 case genai.FinishReasonSafety,
935 genai.FinishReasonBlocklist,
936 genai.FinishReasonProhibitedContent,
937 genai.FinishReasonSPII,
938 genai.FinishReasonImageSafety:
939 return fantasy.FinishReasonContentFilter
940 case genai.FinishReasonRecitation,
941 genai.FinishReasonLanguage,
942 genai.FinishReasonMalformedFunctionCall:
943 return fantasy.FinishReasonError
944 case genai.FinishReasonOther:
945 return fantasy.FinishReasonOther
946 default:
947 return fantasy.FinishReasonUnknown
948 }
949}
950
951func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage {
952 return fantasy.Usage{
953 InputTokens: int64(usage.ToolUsePromptTokenCount),
954 OutputTokens: int64(usage.CandidatesTokenCount),
955 TotalTokens: int64(usage.TotalTokenCount),
956 ReasoningTokens: int64(usage.ThoughtsTokenCount),
957 CacheCreationTokens: int64(usage.CachedContentTokenCount),
958 CacheReadTokens: 0,
959 }
960}