google.go

package google

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"maps"
	"net/http"
	"strings"

	"cloud.google.com/go/auth"
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/anthropic"
	"github.com/charmbracelet/x/exp/slice"
	"github.com/google/uuid"
	"google.golang.org/genai"
)

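// Name is the default provider name used to identify this provider.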
const Name = "google"

type provider struct {
	options options
}

type options struct {
	apiKey   string
	name     string
	headers  map[string]string
	client   *http.Client
	backend  genai.Backend
	project  string
	location string
	skipAuth bool
}

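// Option configures the provider created by New.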
type Option = func(*options)

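// New creates a new Google provider with the given options. A minimal,
// illustrative usage (the model ID below is an example only):
//
//	p := New(WithGeminiAPIKey(os.Getenv("GEMINI_API_KEY")))
//	model, err := p.LanguageModel("gemini-2.0-flash")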
func New(opts ...Option) ai.Provider {
	options := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	options.name = cmp.Or(options.name, Name)

	return &provider{
		options: options,
	}
}

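// WithGeminiAPIKey configures the provider to use the Gemini API backend
// with the given API key, clearing any Vertex project and location.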
func WithGeminiAPIKey(apiKey string) Option {
	return func(o *options) {
		o.backend = genai.BackendGeminiAPI
		o.apiKey = apiKey
		o.project = ""
		o.location = ""
	}
}

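// WithVertex configures the provider to use the Vertex AI backend with the
// given project and location. It panics if either is empty.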
func WithVertex(project, location string) Option {
	if project == "" || location == "" {
		panic("project and location must be provided")
	}
	return func(o *options) {
		o.backend = genai.BackendVertexAI
		o.apiKey = ""
		o.project = project
		o.location = location
	}
}

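// WithSkipAuth, when true, skips Google authentication and uses a dummy
// token provider instead.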
func WithSkipAuth(skipAuth bool) Option {
	return func(o *options) {
		o.skipAuth = skipAuth
	}
}

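// WithName overrides the provider name reported by the language models this
// provider creates.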
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

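// WithHeaders merges the given headers into the provider's request headers.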
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

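// WithHTTPClient sets the HTTP client used for requests.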
func WithHTTPClient(client *http.Client) Option {
	return func(o *options) {
		o.client = client
	}
}

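// Name implements ai.Provider.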
func (*provider) Name() string {
	return Name
}

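// ParseOptions parses raw provider options into *ProviderOptions.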
func (*provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
	var options ProviderOptions
	if err := ai.ParseOptions(data, &options); err != nil {
		return nil, err
	}
	return &options, nil
}

type languageModel struct {
	provider        string
	modelID         string
	client          *genai.Client
	providerOptions options
}

// LanguageModel implements ai.Provider.
func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	// Claude models on Vertex are served through the Anthropic provider.
	if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
		return anthropic.New(
			anthropic.WithVertex(g.options.project, g.options.location),
			anthropic.WithHTTPClient(g.options.client),
			anthropic.WithSkipGoogleAuth(g.options.skipAuth),
		).LanguageModel(modelID)
	}

	cc := &genai.ClientConfig{
		HTTPClient: g.options.client,
		Backend:    g.options.backend,
		APIKey:     g.options.apiKey,
		Project:    g.options.project,
		Location:   g.options.location,
	}
	if g.options.skipAuth {
		cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
	}
	client, err := genai.NewClient(context.Background(), cc)
	if err != nil {
		return nil, err
	}
	return &languageModel{
		modelID:         modelID,
		provider:        g.options.name,
		providerOptions: g.options,
		client:          client,
	}, nil
}

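// prepareParams converts an ai.Call into a genai generation config and
// content list, collecting any call warnings along the way.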
func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
	config := &genai.GenerateContentConfig{}

	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, nil, nil, ai.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
		}
	}

	systemInstructions, content, warnings := toGooglePrompt(call.Prompt)

	if providerOptions.ThinkingConfig != nil {
		if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
			*providerOptions.ThinkingConfig.IncludeThoughts &&
			!strings.HasPrefix(a.provider, "google.vertex.") {
			warnings = append(warnings, ai.CallWarning{
				Type: ai.CallWarningTypeOther,
				Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
					"and might not be supported or could behave unexpectedly with the current Google provider " +
					fmt.Sprintf("(%s)", a.provider),
			})
		}

		if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
			*providerOptions.ThinkingConfig.ThinkingBudget < 128 {
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeOther,
				Message: "The 'thinkingBudget' option cannot be under 128 and will be set to 128",
			})
			providerOptions.ThinkingConfig.ThinkingBudget = ai.IntOption(128)
		}
	}

	isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")

	// Gemma models don't support system instructions, so fold them into the
	// first user message instead.
	if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
		if len(content) > 0 && content[0].Role == genai.RoleUser {
			systemParts := []string{}
			for _, sp := range systemInstructions.Parts {
				systemParts = append(systemParts, sp.Text)
			}
			systemMsg := strings.Join(systemParts, "\n")
			content[0].Parts = append([]*genai.Part{
				{
					Text: systemMsg + "\n\n",
				},
			}, content[0].Parts...)
			systemInstructions = nil
		}
	}

	config.SystemInstruction = systemInstructions

	if call.MaxOutputTokens != nil {
		config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
	}

	if call.Temperature != nil {
		tmp := float32(*call.Temperature)
		config.Temperature = &tmp
	}
	if call.TopK != nil {
		tmp := float32(*call.TopK)
		config.TopK = &tmp
	}
	if call.TopP != nil {
		tmp := float32(*call.TopP)
		config.TopP = &tmp
	}
	if call.FrequencyPenalty != nil {
		tmp := float32(*call.FrequencyPenalty)
		config.FrequencyPenalty = &tmp
	}
	if call.PresencePenalty != nil {
		tmp := float32(*call.PresencePenalty)
		config.PresencePenalty = &tmp
	}

	if providerOptions.ThinkingConfig != nil {
		config.ThinkingConfig = &genai.ThinkingConfig{}
		if providerOptions.ThinkingConfig.IncludeThoughts != nil {
			config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
		}
		if providerOptions.ThinkingConfig.ThinkingBudget != nil {
			tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
			config.ThinkingConfig.ThinkingBudget = &tmp
		}
	}
	for _, safetySetting := range providerOptions.SafetySettings {
		config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
			Category:  genai.HarmCategory(safetySetting.Category),
			Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
		})
	}
	if providerOptions.CachedContent != "" {
		config.CachedContent = providerOptions.CachedContent
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
		config.ToolConfig = toolChoice
		config.Tools = append(config.Tools, &genai.Tool{
			FunctionDeclarations: tools,
		})
		warnings = append(warnings, toolWarnings...)
	}

	return config, content, warnings, nil
}

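// toGooglePrompt converts an ai.Prompt into genai system instructions and
// conversation content.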
func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) { //nolint: unparam
	var systemInstructions *genai.Content
	var content []*genai.Content
	var warnings []ai.CallWarning

	finishedSystemBlock := false
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			if finishedSystemBlock {
				// Only the first system message is used; any later ones are skipped.
				// TODO: see if we need to send an error here?
				continue
			}
			finishedSystemBlock = true

			var systemMessages []string
			for _, part := range msg.Content {
				text, ok := ai.AsMessagePart[ai.TextPart](part)
				if !ok || text.Text == "" {
					continue
				}
				systemMessages = append(systemMessages, text.Text)
			}
			if len(systemMessages) > 0 {
				systemInstructions = &genai.Content{
					Parts: []*genai.Part{
						{
							Text: strings.Join(systemMessages, "\n"),
						},
					},
				}
			}
		case ai.MessageRoleUser:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeFile:
					file, ok := ai.AsMessagePart[ai.FilePart](part)
					if !ok {
						continue
					}
					// Encode requires a destination sized with EncodedLen;
					// encoding into a nil slice would panic.
					encoded := make([]byte, base64.StdEncoding.EncodedLen(len(file.Data)))
					base64.StdEncoding.Encode(encoded, file.Data)
					parts = append(parts, &genai.Part{
						InlineData: &genai.Blob{
							Data:     encoded,
							MIMEType: file.MediaType,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		case ai.MessageRoleAssistant:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeToolCall:
					toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
					if !ok {
						continue
					}

					var result map[string]any
					err := json.Unmarshal([]byte(toolCall.Input), &result)
					if err != nil {
						continue
					}
					parts = append(parts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							ID:   toolCall.ToolCallID,
							Name: toolCall.ToolName,
							Args: result,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleModel,
					Parts: parts,
				})
			}
		case ai.MessageRoleTool:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeToolResult:
					result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
					if !ok {
						continue
					}
					// Find the tool call this result belongs to so we can
					// recover the function name.
					var toolCall ai.ToolCallPart
				lookup:
					for _, m := range prompt {
						if m.Role == ai.MessageRoleAssistant {
							for _, content := range m.Content {
								tc, ok := ai.AsMessagePart[ai.ToolCallPart](content)
								if !ok {
									continue
								}
								if tc.ToolCallID == result.ToolCallID {
									toolCall = tc
									break lookup
								}
							}
						}
					}
					switch result.Output.GetType() {
					case ai.ToolResultContentTypeText:
						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": content.Text}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})
					case ai.ToolResultContentTypeError:
						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": content.Error.Error()}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})
					}
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		default:
			panic("unsupported message role: " + msg.Role)
		}
	}
	return systemInstructions, content, warnings
}

// Generate implements ai.LanguageModel.
func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}

	// The genai chat API takes the history separately from the message being
	// sent, so split off the last message.
	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}

	response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
	if err != nil {
		return nil, err
	}

	return mapResponse(response, warnings)
}

// Model implements ai.LanguageModel.
func (g *languageModel) Model() string {
	return g.modelID
}

// Provider implements ai.LanguageModel.
func (g *languageModel) Provider() string {
	return g.provider
}

// Stream implements ai.LanguageModel.
func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}

	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}

	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		var currentContent string
		var toolCalls []ai.ToolCallContent
		var isActiveText bool
		var isActiveReasoning bool
		var blockCounter int
		var currentTextBlockID string
		var currentReasoningBlockID string
		var usage ai.Usage
		var lastFinishReason ai.FinishReason

		for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
			if err != nil {
				yield(ai.StreamPart{
					Type:  ai.StreamPartTypeError,
					Error: err,
				})
				return
			}

			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
				for _, part := range resp.Candidates[0].Content.Parts {
					switch {
					case part.Text != "":
						delta := part.Text
						if part.Thought {
							// Reasoning/thought part: end any active text
							// block before starting reasoning.
							if isActiveText {
								isActiveText = false
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeTextEnd,
									ID:   currentTextBlockID,
								}) {
									return
								}
							}

							// Start a new reasoning block if not already active.
							if !isActiveReasoning {
								isActiveReasoning = true
								currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
								blockCounter++
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeReasoningStart,
									ID:   currentReasoningBlockID,
								}) {
									return
								}
							}

							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeReasoningDelta,
								ID:    currentReasoningBlockID,
								Delta: delta,
							}) {
								return
							}
						} else {
							// Regular text part: end any active reasoning
							// block before starting text.
							if isActiveReasoning {
								isActiveReasoning = false
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeReasoningEnd,
									ID:   currentReasoningBlockID,
								}) {
									return
								}
							}

							// Start a new text block if not already active.
							if !isActiveText {
								isActiveText = true
								currentTextBlockID = fmt.Sprintf("%d", blockCounter)
								blockCounter++
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeTextStart,
									ID:   currentTextBlockID,
								}) {
									return
								}
							}

							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeTextDelta,
								ID:    currentTextBlockID,
								Delta: delta,
							}) {
								return
							}
							currentContent += delta
						}
					case part.FunctionCall != nil:
						// End any active text or reasoning blocks.
						if isActiveText {
							isActiveText = false
							if !yield(ai.StreamPart{
								Type: ai.StreamPartTypeTextEnd,
								ID:   currentTextBlockID,
							}) {
								return
							}
						}
						if isActiveReasoning {
							isActiveReasoning = false
							if !yield(ai.StreamPart{
								Type: ai.StreamPartTypeReasoningEnd,
								ID:   currentReasoningBlockID,
							}) {
								return
							}
						}

						toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())

						args, err := json.Marshal(part.FunctionCall.Args)
						if err != nil {
							yield(ai.StreamPart{
								Type:  ai.StreamPartTypeError,
								Error: err,
							})
							return
						}

						if !yield(ai.StreamPart{
							Type:         ai.StreamPartTypeToolInputStart,
							ID:           toolCallID,
							ToolCallName: part.FunctionCall.Name,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:  ai.StreamPartTypeToolInputDelta,
							ID:    toolCallID,
							Delta: string(args),
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeToolInputEnd,
							ID:   toolCallID,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:             ai.StreamPartTypeToolCall,
							ID:               toolCallID,
							ToolCallName:     part.FunctionCall.Name,
							ToolCallInput:    string(args),
							ProviderExecuted: false,
						}) {
							return
						}

						toolCalls = append(toolCalls, ai.ToolCallContent{
							ToolCallID:       toolCallID,
							ToolName:         part.FunctionCall.Name,
							Input:            string(args),
							ProviderExecuted: false,
						})
					}
				}
			}

			if resp.UsageMetadata != nil {
				usage = mapUsage(resp.UsageMetadata)
			}

			if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
				lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
			}
		}

		// Close any open blocks before finishing.
		if isActiveText {
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID:   currentTextBlockID,
			}) {
				return
			}
		}
		if isActiveReasoning {
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeReasoningEnd,
				ID:   currentReasoningBlockID,
			}) {
				return
			}
		}

		finishReason := lastFinishReason
		if len(toolCalls) > 0 {
			finishReason = ai.FinishReasonToolCalls
		} else if finishReason == "" {
			finishReason = ai.FinishReasonStop
		}

		yield(ai.StreamPart{
			Type:         ai.StreamPartTypeFinish,
			Usage:        usage,
			FinishReason: finishReason,
		})
	}, nil
}

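// toGoogleTools converts ai tools and tool choice into genai function
// declarations and a matching tool config, warning about unsupported tool
// types.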
func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}

			required := []string{}
			var properties map[string]any
			if props, ok := ft.InputSchema["properties"]; ok {
				properties, _ = props.(map[string]any)
			}
			// The schema may be built in Go ([]string) or decoded from JSON
			// ([]any), so handle both shapes of "required".
			switch req := ft.InputSchema["required"].(type) {
			case []string:
				required = req
			case []any:
				for _, r := range req {
					if s, ok := r.(string); ok {
						required = append(required, s)
					}
				}
			}
			declaration := &genai.FunctionDeclaration{
				Name:        ft.Name,
				Description: ft.Description,
				Parameters: &genai.Schema{
					Type:       genai.TypeObject,
					Properties: convertSchemaProperties(properties),
					Required:   required,
				},
			}
			googleTools = append(googleTools, declaration)
			continue
		}
		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return googleTools, googleToolChoice, warnings
	}
	switch *toolChoice {
	case ai.ToolChoiceAuto:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAuto,
			},
		}
	case ai.ToolChoiceRequired:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
			},
		}
	case ai.ToolChoiceNone:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeNone,
			},
		}
	default:
		// Any other value names a specific tool the model must call.
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
				AllowedFunctionNames: []string{
					string(*toolChoice),
				},
			},
		}
	}
	return googleTools, googleToolChoice, warnings
}

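// convertSchemaProperties converts JSON Schema properties into genai schemas.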
func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

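// convertToSchema converts a single JSON Schema definition into a
// genai.Schema, defaulting to string when no usable type is present.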
func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGoogle(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

func mapJSONTypeToGoogle(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

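// mapResponse converts a genai response into an ai.Response.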
func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
	if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
		return nil, errors.New("no response from model")
	}

	var (
		content      []ai.Content
		finishReason ai.FinishReason
		hasToolCalls bool
		candidate    = response.Candidates[0]
	)

	for _, part := range candidate.Content.Parts {
		switch {
		case part.Text != "":
			if part.Thought {
				content = append(content, ai.ReasoningContent{Text: part.Text})
			} else {
				content = append(content, ai.TextContent{Text: part.Text})
			}
		case part.FunctionCall != nil:
			input, err := json.Marshal(part.FunctionCall.Args)
			if err != nil {
				return nil, err
			}
			toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
			content = append(content, ai.ToolCallContent{
				ToolCallID:       toolCallID,
				ToolName:         part.FunctionCall.Name,
				Input:            string(input),
				ProviderExecuted: false,
			})
			hasToolCalls = true
		default:
			// Silently skip unknown part types instead of erroring.
			// This allows for forward compatibility with new part types.
		}
	}

	if hasToolCalls {
		finishReason = ai.FinishReasonToolCalls
	} else {
		finishReason = mapFinishReason(candidate.FinishReason)
	}

	return &ai.Response{
		Content:      content,
		Usage:        mapUsage(response.UsageMetadata),
		FinishReason: finishReason,
		Warnings:     warnings,
	}, nil
}

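// mapFinishReason maps a genai finish reason onto the corresponding ai
// finish reason.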
func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return ai.FinishReasonStop
	case genai.FinishReasonMaxTokens:
		return ai.FinishReasonLength
	case genai.FinishReasonSafety,
		genai.FinishReasonBlocklist,
		genai.FinishReasonProhibitedContent,
		genai.FinishReasonSPII,
		genai.FinishReasonImageSafety:
		return ai.FinishReasonContentFilter
	case genai.FinishReasonRecitation,
		genai.FinishReasonLanguage,
		genai.FinishReasonMalformedFunctionCall:
		return ai.FinishReasonError
	case genai.FinishReasonOther:
		return ai.FinishReasonOther
	default:
		return ai.FinishReasonUnknown
	}
}

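// mapUsage converts genai usage metadata into ai.Usage; nil metadata yields
// a zero Usage.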
func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
	if usage == nil {
		return ai.Usage{}
	}
	return ai.Usage{
		InputTokens:     int64(usage.PromptTokenCount),
		OutputTokens:    int64(usage.CandidatesTokenCount),
		TotalTokens:     int64(usage.TotalTokenCount),
		ReasoningTokens: int64(usage.ThoughtsTokenCount),
		// CachedContentTokenCount counts prompt tokens served from the
		// cache, so it maps to cache reads, not cache creation.
		CacheCreationTokens: 0,
		CacheReadTokens:     int64(usage.CachedContentTokenCount),
	}
}