1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "maps"
11 "net/http"
12 "strings"
13
14 "charm.land/fantasy"
15 "charm.land/fantasy/anthropic"
16 "cloud.google.com/go/auth"
17 "github.com/charmbracelet/x/exp/slice"
18 "github.com/google/uuid"
19 "google.golang.org/genai"
20)
21
// Name is the default provider name. New falls back to it when no name was
// configured, and it is the key under which prepareParams looks up
// provider-specific call options.
const Name = "google"
23
// provider is the value returned by New. It only stores the resolved options;
// the genai client is created per model inside LanguageModel.
type provider struct {
	options options
}
27
// options collects everything the With* functions can configure.
type options struct {
	// apiKey is forwarded to genai.ClientConfig.APIKey.
	apiKey string
	// name is reported by language models as their provider name.
	name string
	// baseURL, when non-empty, is forwarded to genai.HTTPOptions.BaseURL.
	baseURL string
	// headers are copied into genai.HTTPOptions.Headers.
	headers map[string]string
	// client is the HTTP client handed to genai (and to the anthropic provider).
	client *http.Client
	// backend selects between the Gemini API and Vertex AI.
	backend genai.Backend
	// project and location are set by WithVertex and cleared by WithGeminiAPIKey.
	project  string
	location string
	// skipAuth installs a dummy token provider as the client credentials.
	skipAuth bool
}
39
// Option is a functional option applied by New to the provider configuration.
type Option = func(*options)
41
42func New(opts ...Option) fantasy.Provider {
43 options := options{
44 headers: map[string]string{},
45 }
46 for _, o := range opts {
47 o(&options)
48 }
49
50 options.name = cmp.Or(options.name, Name)
51
52 return &provider{
53 options: options,
54 }
55}
56
57func WithBaseURL(baseURL string) Option {
58 return func(o *options) {
59 o.baseURL = baseURL
60 }
61}
62
63func WithGeminiAPIKey(apiKey string) Option {
64 return func(o *options) {
65 o.backend = genai.BackendGeminiAPI
66 o.apiKey = apiKey
67 o.project = ""
68 o.location = ""
69 }
70}
71
72func WithVertex(project, location string) Option {
73 if project == "" || location == "" {
74 panic("project and location must be provided")
75 }
76 return func(o *options) {
77 o.backend = genai.BackendVertexAI
78 o.apiKey = ""
79 o.project = project
80 o.location = location
81 }
82}
83
84func WithSkipAuth(skipAuth bool) Option {
85 return func(o *options) {
86 o.skipAuth = skipAuth
87 }
88}
89
90func WithName(name string) Option {
91 return func(o *options) {
92 o.name = name
93 }
94}
95
96func WithHeaders(headers map[string]string) Option {
97 return func(o *options) {
98 maps.Copy(o.headers, headers)
99 }
100}
101
102func WithHTTPClient(client *http.Client) Option {
103 return func(o *options) {
104 o.client = client
105 }
106}
107
// Name implements fantasy.Provider. It always returns the package-level Name
// constant; a name configured through WithName is not consulted here.
// NOTE(review): language models report options.name instead — confirm whether
// the two are meant to differ.
func (*provider) Name() string {
	return Name
}
111
// languageModel is the fantasy.LanguageModel backed by a genai client.
type languageModel struct {
	// provider is the provider name returned by Provider.
	provider string
	// modelID is passed to genai when creating chats and returned by Model.
	modelID string
	// client is the genai client created in (*provider).LanguageModel.
	client *genai.Client
	// providerOptions is a copy of the provider configuration the model was
	// created from.
	providerOptions options
}
118
119// LanguageModel implements fantasy.Provider.
120func (g *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error) {
121 if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
122 return anthropic.New(
123 anthropic.WithVertex(g.options.project, g.options.location),
124 anthropic.WithHTTPClient(g.options.client),
125 anthropic.WithSkipAuth(g.options.skipAuth),
126 ).LanguageModel(modelID)
127 }
128
129 cc := &genai.ClientConfig{
130 HTTPClient: g.options.client,
131 Backend: g.options.backend,
132 APIKey: g.options.apiKey,
133 Project: g.options.project,
134 Location: g.options.location,
135 }
136 if g.options.skipAuth {
137 cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
138 }
139
140 if g.options.baseURL != "" || len(g.options.headers) > 0 {
141 headers := http.Header{}
142 for k, v := range g.options.headers {
143 headers.Add(k, v)
144 }
145 cc.HTTPOptions = genai.HTTPOptions{
146 BaseURL: g.options.baseURL,
147 Headers: headers,
148 }
149 }
150 client, err := genai.NewClient(context.Background(), cc)
151 if err != nil {
152 return nil, err
153 }
154 return &languageModel{
155 modelID: modelID,
156 provider: g.options.name,
157 providerOptions: g.options,
158 client: client,
159 }, nil
160}
161
162func (a languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) {
163 config := &genai.GenerateContentConfig{}
164
165 providerOptions := &ProviderOptions{}
166 if v, ok := call.ProviderOptions[Name]; ok {
167 providerOptions, ok = v.(*ProviderOptions)
168 if !ok {
169 return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
170 }
171 }
172
173 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
174
175 if providerOptions.ThinkingConfig != nil {
176 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
177 *providerOptions.ThinkingConfig.IncludeThoughts &&
178 strings.HasPrefix(a.provider, "google.vertex.") {
179 warnings = append(warnings, fantasy.CallWarning{
180 Type: fantasy.CallWarningTypeOther,
181 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
182 "and might not be supported or could behave unexpectedly with the current Google provider " +
183 fmt.Sprintf("(%s)", a.provider),
184 })
185 }
186
187 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
188 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
189 warnings = append(warnings, fantasy.CallWarning{
190 Type: fantasy.CallWarningTypeOther,
191 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
192 })
193 providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128))
194 }
195 }
196
197 isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")
198
199 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
200 if len(content) > 0 && content[0].Role == genai.RoleUser {
201 systemParts := []string{}
202 for _, sp := range systemInstructions.Parts {
203 systemParts = append(systemParts, sp.Text)
204 }
205 systemMsg := strings.Join(systemParts, "\n")
206 content[0].Parts = append([]*genai.Part{
207 {
208 Text: systemMsg + "\n\n",
209 },
210 }, content[0].Parts...)
211 systemInstructions = nil
212 }
213 }
214
215 config.SystemInstruction = systemInstructions
216
217 if call.MaxOutputTokens != nil {
218 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
219 }
220
221 if call.Temperature != nil {
222 tmp := float32(*call.Temperature)
223 config.Temperature = &tmp
224 }
225 if call.TopK != nil {
226 tmp := float32(*call.TopK)
227 config.TopK = &tmp
228 }
229 if call.TopP != nil {
230 tmp := float32(*call.TopP)
231 config.TopP = &tmp
232 }
233 if call.FrequencyPenalty != nil {
234 tmp := float32(*call.FrequencyPenalty)
235 config.FrequencyPenalty = &tmp
236 }
237 if call.PresencePenalty != nil {
238 tmp := float32(*call.PresencePenalty)
239 config.PresencePenalty = &tmp
240 }
241
242 if providerOptions.ThinkingConfig != nil {
243 config.ThinkingConfig = &genai.ThinkingConfig{}
244 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
245 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
246 }
247 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
248 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
249 config.ThinkingConfig.ThinkingBudget = &tmp
250 }
251 }
252 for _, safetySetting := range providerOptions.SafetySettings {
253 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
254 Category: genai.HarmCategory(safetySetting.Category),
255 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
256 })
257 }
258 if providerOptions.CachedContent != "" {
259 config.CachedContent = providerOptions.CachedContent
260 }
261
262 if len(call.Tools) > 0 {
263 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
264 config.ToolConfig = toolChoice
265 config.Tools = append(config.Tools, &genai.Tool{
266 FunctionDeclarations: tools,
267 })
268 warnings = append(warnings, toolWarnings...)
269 }
270
271 return config, content, warnings, nil
272}
273
274func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam
275 var systemInstructions *genai.Content
276 var content []*genai.Content
277 var warnings []fantasy.CallWarning
278
279 finishedSystemBlock := false
280 for _, msg := range prompt {
281 switch msg.Role {
282 case fantasy.MessageRoleSystem:
283 if finishedSystemBlock {
284 // skip multiple system messages that are separated by user/assistant messages
285 // TODO: see if we need to send error here?
286 continue
287 }
288 finishedSystemBlock = true
289
290 var systemMessages []string
291 for _, part := range msg.Content {
292 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
293 if !ok || text.Text == "" {
294 continue
295 }
296 systemMessages = append(systemMessages, text.Text)
297 }
298 if len(systemMessages) > 0 {
299 systemInstructions = &genai.Content{
300 Parts: []*genai.Part{
301 {
302 Text: strings.Join(systemMessages, "\n"),
303 },
304 },
305 }
306 }
307 case fantasy.MessageRoleUser:
308 var parts []*genai.Part
309 for _, part := range msg.Content {
310 switch part.GetType() {
311 case fantasy.ContentTypeText:
312 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
313 if !ok || text.Text == "" {
314 continue
315 }
316 parts = append(parts, &genai.Part{
317 Text: text.Text,
318 })
319 case fantasy.ContentTypeFile:
320 file, ok := fantasy.AsMessagePart[fantasy.FilePart](part)
321 if !ok {
322 continue
323 }
324 var encoded []byte
325 base64.StdEncoding.Encode(encoded, file.Data)
326 parts = append(parts, &genai.Part{
327 InlineData: &genai.Blob{
328 Data: encoded,
329 MIMEType: file.MediaType,
330 },
331 })
332 }
333 }
334 if len(parts) > 0 {
335 content = append(content, &genai.Content{
336 Role: genai.RoleUser,
337 Parts: parts,
338 })
339 }
340 case fantasy.MessageRoleAssistant:
341 var parts []*genai.Part
342 for _, part := range msg.Content {
343 switch part.GetType() {
344 case fantasy.ContentTypeText:
345 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
346 if !ok || text.Text == "" {
347 continue
348 }
349 parts = append(parts, &genai.Part{
350 Text: text.Text,
351 })
352 case fantasy.ContentTypeToolCall:
353 toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
354 if !ok {
355 continue
356 }
357
358 var result map[string]any
359 err := json.Unmarshal([]byte(toolCall.Input), &result)
360 if err != nil {
361 continue
362 }
363 parts = append(parts, &genai.Part{
364 FunctionCall: &genai.FunctionCall{
365 ID: toolCall.ToolCallID,
366 Name: toolCall.ToolName,
367 Args: result,
368 },
369 })
370 }
371 }
372 if len(parts) > 0 {
373 content = append(content, &genai.Content{
374 Role: genai.RoleModel,
375 Parts: parts,
376 })
377 }
378 case fantasy.MessageRoleTool:
379 var parts []*genai.Part
380 for _, part := range msg.Content {
381 switch part.GetType() {
382 case fantasy.ContentTypeToolResult:
383 result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
384 if !ok {
385 continue
386 }
387 var toolCall fantasy.ToolCallPart
388 for _, m := range prompt {
389 if m.Role == fantasy.MessageRoleAssistant {
390 for _, content := range m.Content {
391 tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content)
392 if !ok {
393 continue
394 }
395 if tc.ToolCallID == result.ToolCallID {
396 toolCall = tc
397 break
398 }
399 }
400 }
401 }
402 switch result.Output.GetType() {
403 case fantasy.ToolResultContentTypeText:
404 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output)
405 if !ok {
406 continue
407 }
408 response := map[string]any{"result": content.Text}
409 parts = append(parts, &genai.Part{
410 FunctionResponse: &genai.FunctionResponse{
411 ID: result.ToolCallID,
412 Response: response,
413 Name: toolCall.ToolName,
414 },
415 })
416
417 case fantasy.ToolResultContentTypeError:
418 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
419 if !ok {
420 continue
421 }
422 response := map[string]any{"result": content.Error.Error()}
423 parts = append(parts, &genai.Part{
424 FunctionResponse: &genai.FunctionResponse{
425 ID: result.ToolCallID,
426 Response: response,
427 Name: toolCall.ToolName,
428 },
429 })
430 }
431 }
432 }
433 if len(parts) > 0 {
434 content = append(content, &genai.Content{
435 Role: genai.RoleUser,
436 Parts: parts,
437 })
438 }
439 default:
440 panic("unsupported message role: " + msg.Role)
441 }
442 }
443 return systemInstructions, content, warnings
444}
445
446// Generate implements fantasy.LanguageModel.
447func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
448 config, contents, warnings, err := g.prepareParams(call)
449 if err != nil {
450 return nil, err
451 }
452
453 lastMessage, history, ok := slice.Pop(contents)
454 if !ok {
455 return nil, errors.New("no messages to send")
456 }
457
458 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
459 if err != nil {
460 return nil, err
461 }
462
463 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
464 if err != nil {
465 return nil, err
466 }
467
468 return mapResponse(response, warnings)
469}
470
// Model implements fantasy.LanguageModel. It returns the model ID this
// language model was created with.
func (g *languageModel) Model() string {
	return g.modelID
}
475
// Provider implements fantasy.LanguageModel. It returns the provider name the
// model was created with (the configured name, or Name by default).
func (g *languageModel) Provider() string {
	return g.provider
}
480
481// Stream implements fantasy.LanguageModel.
482func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
483 config, contents, warnings, err := g.prepareParams(call)
484 if err != nil {
485 return nil, err
486 }
487
488 lastMessage, history, ok := slice.Pop(contents)
489 if !ok {
490 return nil, errors.New("no messages to send")
491 }
492
493 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
494 if err != nil {
495 return nil, err
496 }
497
498 return func(yield func(fantasy.StreamPart) bool) {
499 if len(warnings) > 0 {
500 if !yield(fantasy.StreamPart{
501 Type: fantasy.StreamPartTypeWarnings,
502 Warnings: warnings,
503 }) {
504 return
505 }
506 }
507
508 var currentContent string
509 var toolCalls []fantasy.ToolCallContent
510 var isActiveText bool
511 var isActiveReasoning bool
512 var blockCounter int
513 var currentTextBlockID string
514 var currentReasoningBlockID string
515 var usage fantasy.Usage
516 var lastFinishReason fantasy.FinishReason
517
518 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
519 if err != nil {
520 yield(fantasy.StreamPart{
521 Type: fantasy.StreamPartTypeError,
522 Error: err,
523 })
524 return
525 }
526
527 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
528 for _, part := range resp.Candidates[0].Content.Parts {
529 switch {
530 case part.Text != "":
531 delta := part.Text
532 if delta != "" {
533 // Check if this is a reasoning/thought part
534 if part.Thought {
535 // End any active text block before starting reasoning
536 if isActiveText {
537 isActiveText = false
538 if !yield(fantasy.StreamPart{
539 Type: fantasy.StreamPartTypeTextEnd,
540 ID: currentTextBlockID,
541 }) {
542 return
543 }
544 }
545
546 // Start new reasoning block if not already active
547 if !isActiveReasoning {
548 isActiveReasoning = true
549 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
550 blockCounter++
551 if !yield(fantasy.StreamPart{
552 Type: fantasy.StreamPartTypeReasoningStart,
553 ID: currentReasoningBlockID,
554 }) {
555 return
556 }
557 }
558
559 if !yield(fantasy.StreamPart{
560 Type: fantasy.StreamPartTypeReasoningDelta,
561 ID: currentReasoningBlockID,
562 Delta: delta,
563 }) {
564 return
565 }
566 } else {
567 // Regular text part
568 // End any active reasoning block before starting text
569 if isActiveReasoning {
570 isActiveReasoning = false
571 if !yield(fantasy.StreamPart{
572 Type: fantasy.StreamPartTypeReasoningEnd,
573 ID: currentReasoningBlockID,
574 }) {
575 return
576 }
577 }
578
579 // Start new text block if not already active
580 if !isActiveText {
581 isActiveText = true
582 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
583 blockCounter++
584 if !yield(fantasy.StreamPart{
585 Type: fantasy.StreamPartTypeTextStart,
586 ID: currentTextBlockID,
587 }) {
588 return
589 }
590 }
591
592 if !yield(fantasy.StreamPart{
593 Type: fantasy.StreamPartTypeTextDelta,
594 ID: currentTextBlockID,
595 Delta: delta,
596 }) {
597 return
598 }
599 currentContent += delta
600 }
601 }
602 case part.FunctionCall != nil:
603 // End any active text or reasoning blocks
604 if isActiveText {
605 isActiveText = false
606 if !yield(fantasy.StreamPart{
607 Type: fantasy.StreamPartTypeTextEnd,
608 ID: currentTextBlockID,
609 }) {
610 return
611 }
612 }
613 if isActiveReasoning {
614 isActiveReasoning = false
615 if !yield(fantasy.StreamPart{
616 Type: fantasy.StreamPartTypeReasoningEnd,
617 ID: currentReasoningBlockID,
618 }) {
619 return
620 }
621 }
622
623 toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
624
625 args, err := json.Marshal(part.FunctionCall.Args)
626 if err != nil {
627 yield(fantasy.StreamPart{
628 Type: fantasy.StreamPartTypeError,
629 Error: err,
630 })
631 return
632 }
633
634 if !yield(fantasy.StreamPart{
635 Type: fantasy.StreamPartTypeToolInputStart,
636 ID: toolCallID,
637 ToolCallName: part.FunctionCall.Name,
638 }) {
639 return
640 }
641
642 if !yield(fantasy.StreamPart{
643 Type: fantasy.StreamPartTypeToolInputDelta,
644 ID: toolCallID,
645 Delta: string(args),
646 }) {
647 return
648 }
649
650 if !yield(fantasy.StreamPart{
651 Type: fantasy.StreamPartTypeToolInputEnd,
652 ID: toolCallID,
653 }) {
654 return
655 }
656
657 if !yield(fantasy.StreamPart{
658 Type: fantasy.StreamPartTypeToolCall,
659 ID: toolCallID,
660 ToolCallName: part.FunctionCall.Name,
661 ToolCallInput: string(args),
662 ProviderExecuted: false,
663 }) {
664 return
665 }
666
667 toolCalls = append(toolCalls, fantasy.ToolCallContent{
668 ToolCallID: toolCallID,
669 ToolName: part.FunctionCall.Name,
670 Input: string(args),
671 ProviderExecuted: false,
672 })
673 }
674 }
675 }
676
677 if resp.UsageMetadata != nil {
678 usage = mapUsage(resp.UsageMetadata)
679 }
680
681 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
682 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
683 }
684 }
685
686 // Close any open blocks before finishing
687 if isActiveText {
688 if !yield(fantasy.StreamPart{
689 Type: fantasy.StreamPartTypeTextEnd,
690 ID: currentTextBlockID,
691 }) {
692 return
693 }
694 }
695 if isActiveReasoning {
696 if !yield(fantasy.StreamPart{
697 Type: fantasy.StreamPartTypeReasoningEnd,
698 ID: currentReasoningBlockID,
699 }) {
700 return
701 }
702 }
703
704 finishReason := lastFinishReason
705 if len(toolCalls) > 0 {
706 finishReason = fantasy.FinishReasonToolCalls
707 } else if finishReason == "" {
708 finishReason = fantasy.FinishReasonStop
709 }
710
711 yield(fantasy.StreamPart{
712 Type: fantasy.StreamPartTypeFinish,
713 Usage: usage,
714 FinishReason: finishReason,
715 })
716 }, nil
717}
718
719func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) {
720 for _, tool := range tools {
721 if tool.GetType() == fantasy.ToolTypeFunction {
722 ft, ok := tool.(fantasy.FunctionTool)
723 if !ok {
724 continue
725 }
726
727 required := []string{}
728 var properties map[string]any
729 if props, ok := ft.InputSchema["properties"]; ok {
730 properties, _ = props.(map[string]any)
731 }
732 if req, ok := ft.InputSchema["required"]; ok {
733 if reqArr, ok := req.([]string); ok {
734 required = reqArr
735 }
736 }
737 declaration := &genai.FunctionDeclaration{
738 Name: ft.Name,
739 Description: ft.Description,
740 Parameters: &genai.Schema{
741 Type: genai.TypeObject,
742 Properties: convertSchemaProperties(properties),
743 Required: required,
744 },
745 }
746 googleTools = append(googleTools, declaration)
747 continue
748 }
749 // TODO: handle provider tool calls
750 warnings = append(warnings, fantasy.CallWarning{
751 Type: fantasy.CallWarningTypeUnsupportedTool,
752 Tool: tool,
753 Message: "tool is not supported",
754 })
755 }
756 if toolChoice == nil {
757 return googleTools, googleToolChoice, warnings
758 }
759 switch *toolChoice {
760 case fantasy.ToolChoiceAuto:
761 googleToolChoice = &genai.ToolConfig{
762 FunctionCallingConfig: &genai.FunctionCallingConfig{
763 Mode: genai.FunctionCallingConfigModeAuto,
764 },
765 }
766 case fantasy.ToolChoiceRequired:
767 googleToolChoice = &genai.ToolConfig{
768 FunctionCallingConfig: &genai.FunctionCallingConfig{
769 Mode: genai.FunctionCallingConfigModeAny,
770 },
771 }
772 case fantasy.ToolChoiceNone:
773 googleToolChoice = &genai.ToolConfig{
774 FunctionCallingConfig: &genai.FunctionCallingConfig{
775 Mode: genai.FunctionCallingConfigModeNone,
776 },
777 }
778 default:
779 googleToolChoice = &genai.ToolConfig{
780 FunctionCallingConfig: &genai.FunctionCallingConfig{
781 Mode: genai.FunctionCallingConfigModeAny,
782 AllowedFunctionNames: []string{
783 string(*toolChoice),
784 },
785 },
786 }
787 }
788 return googleTools, googleToolChoice, warnings
789}
790
791func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
792 properties := make(map[string]*genai.Schema)
793
794 for name, param := range parameters {
795 properties[name] = convertToSchema(param)
796 }
797
798 return properties
799}
800
801func convertToSchema(param any) *genai.Schema {
802 schema := &genai.Schema{Type: genai.TypeString}
803
804 paramMap, ok := param.(map[string]any)
805 if !ok {
806 return schema
807 }
808
809 if desc, ok := paramMap["description"].(string); ok {
810 schema.Description = desc
811 }
812
813 typeVal, hasType := paramMap["type"]
814 if !hasType {
815 return schema
816 }
817
818 typeStr, ok := typeVal.(string)
819 if !ok {
820 return schema
821 }
822
823 schema.Type = mapJSONTypeToGoogle(typeStr)
824
825 switch typeStr {
826 case "array":
827 schema.Items = processArrayItems(paramMap)
828 case "object":
829 if props, ok := paramMap["properties"].(map[string]any); ok {
830 schema.Properties = convertSchemaProperties(props)
831 }
832 }
833
834 return schema
835}
836
837func processArrayItems(paramMap map[string]any) *genai.Schema {
838 items, ok := paramMap["items"].(map[string]any)
839 if !ok {
840 return nil
841 }
842
843 return convertToSchema(items)
844}
845
846func mapJSONTypeToGoogle(jsonType string) genai.Type {
847 switch jsonType {
848 case "string":
849 return genai.TypeString
850 case "number":
851 return genai.TypeNumber
852 case "integer":
853 return genai.TypeInteger
854 case "boolean":
855 return genai.TypeBoolean
856 case "array":
857 return genai.TypeArray
858 case "object":
859 return genai.TypeObject
860 default:
861 return genai.TypeString // Default to string for unknown types
862 }
863}
864
865func mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) {
866 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
867 return nil, errors.New("no response from model")
868 }
869
870 var (
871 content []fantasy.Content
872 finishReason fantasy.FinishReason
873 hasToolCalls bool
874 candidate = response.Candidates[0]
875 )
876
877 for _, part := range candidate.Content.Parts {
878 switch {
879 case part.Text != "":
880 if part.Thought {
881 content = append(content, fantasy.ReasoningContent{Text: part.Text})
882 } else {
883 content = append(content, fantasy.TextContent{Text: part.Text})
884 }
885 case part.FunctionCall != nil:
886 input, err := json.Marshal(part.FunctionCall.Args)
887 if err != nil {
888 return nil, err
889 }
890 toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
891 content = append(content, fantasy.ToolCallContent{
892 ToolCallID: toolCallID,
893 ToolName: part.FunctionCall.Name,
894 Input: string(input),
895 ProviderExecuted: false,
896 })
897 hasToolCalls = true
898 default:
899 // Silently skip unknown part types instead of erroring
900 // This allows for forward compatibility with new part types
901 }
902 }
903
904 if hasToolCalls {
905 finishReason = fantasy.FinishReasonToolCalls
906 } else {
907 finishReason = mapFinishReason(candidate.FinishReason)
908 }
909
910 return &fantasy.Response{
911 Content: content,
912 Usage: mapUsage(response.UsageMetadata),
913 FinishReason: finishReason,
914 Warnings: warnings,
915 }, nil
916}
917
918func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason {
919 switch reason {
920 case genai.FinishReasonStop:
921 return fantasy.FinishReasonStop
922 case genai.FinishReasonMaxTokens:
923 return fantasy.FinishReasonLength
924 case genai.FinishReasonSafety,
925 genai.FinishReasonBlocklist,
926 genai.FinishReasonProhibitedContent,
927 genai.FinishReasonSPII,
928 genai.FinishReasonImageSafety:
929 return fantasy.FinishReasonContentFilter
930 case genai.FinishReasonRecitation,
931 genai.FinishReasonLanguage,
932 genai.FinishReasonMalformedFunctionCall:
933 return fantasy.FinishReasonError
934 case genai.FinishReasonOther:
935 return fantasy.FinishReasonOther
936 default:
937 return fantasy.FinishReasonUnknown
938 }
939}
940
941func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage {
942 return fantasy.Usage{
943 InputTokens: int64(usage.ToolUsePromptTokenCount),
944 OutputTokens: int64(usage.CandidatesTokenCount),
945 TotalTokens: int64(usage.TotalTokenCount),
946 ReasoningTokens: int64(usage.ThoughtsTokenCount),
947 CacheCreationTokens: int64(usage.CachedContentTokenCount),
948 CacheReadTokens: 0,
949 }
950}