google.go

  1package google
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/base64"
  7	"encoding/json"
  8	"errors"
  9	"fmt"
 10	"maps"
 11	"net/http"
 12	"strings"
 13
 14	"charm.land/fantasy"
 15	"charm.land/fantasy/anthropic"
 16	"cloud.google.com/go/auth"
 17	"github.com/charmbracelet/x/exp/slice"
 18	"github.com/google/uuid"
 19	"google.golang.org/genai"
 20)
 21
 22const Name = "google"
 23
 24type provider struct {
 25	options options
 26}
 27
 28type options struct {
 29	apiKey   string
 30	name     string
 31	baseURL  string
 32	headers  map[string]string
 33	client   *http.Client
 34	backend  genai.Backend
 35	project  string
 36	location string
 37	skipAuth bool
 38}
 39
 40type Option = func(*options)
 41
 42func New(opts ...Option) fantasy.Provider {
 43	options := options{
 44		headers: map[string]string{},
 45	}
 46	for _, o := range opts {
 47		o(&options)
 48	}
 49
 50	options.name = cmp.Or(options.name, Name)
 51
 52	return &provider{
 53		options: options,
 54	}
 55}
 56
 57func WithBaseURL(baseURL string) Option {
 58	return func(o *options) {
 59		o.baseURL = baseURL
 60	}
 61}
 62
 63func WithGeminiAPIKey(apiKey string) Option {
 64	return func(o *options) {
 65		o.backend = genai.BackendGeminiAPI
 66		o.apiKey = apiKey
 67		o.project = ""
 68		o.location = ""
 69	}
 70}
 71
 72func WithVertex(project, location string) Option {
 73	if project == "" || location == "" {
 74		panic("project and location must be provided")
 75	}
 76	return func(o *options) {
 77		o.backend = genai.BackendVertexAI
 78		o.apiKey = ""
 79		o.project = project
 80		o.location = location
 81	}
 82}
 83
 84func WithSkipAuth(skipAuth bool) Option {
 85	return func(o *options) {
 86		o.skipAuth = skipAuth
 87	}
 88}
 89
 90func WithName(name string) Option {
 91	return func(o *options) {
 92		o.name = name
 93	}
 94}
 95
 96func WithHeaders(headers map[string]string) Option {
 97	return func(o *options) {
 98		maps.Copy(o.headers, headers)
 99	}
100}
101
102func WithHTTPClient(client *http.Client) Option {
103	return func(o *options) {
104		o.client = client
105	}
106}
107
108func (*provider) Name() string {
109	return Name
110}
111
112type languageModel struct {
113	provider        string
114	modelID         string
115	client          *genai.Client
116	providerOptions options
117}
118
119// LanguageModel implements fantasy.Provider.
120func (g *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error) {
121	if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
122		return anthropic.New(
123			anthropic.WithVertex(g.options.project, g.options.location),
124			anthropic.WithHTTPClient(g.options.client),
125			anthropic.WithSkipAuth(g.options.skipAuth),
126		).LanguageModel(modelID)
127	}
128
129	cc := &genai.ClientConfig{
130		HTTPClient: g.options.client,
131		Backend:    g.options.backend,
132		APIKey:     g.options.apiKey,
133		Project:    g.options.project,
134		Location:   g.options.location,
135	}
136	if g.options.skipAuth {
137		cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
138	}
139
140	if g.options.baseURL != "" || len(g.options.headers) > 0 {
141		headers := http.Header{}
142		for k, v := range g.options.headers {
143			headers.Add(k, v)
144		}
145		cc.HTTPOptions = genai.HTTPOptions{
146			BaseURL: g.options.baseURL,
147			Headers: headers,
148		}
149	}
150	client, err := genai.NewClient(context.Background(), cc)
151	if err != nil {
152		return nil, err
153	}
154	return &languageModel{
155		modelID:         modelID,
156		provider:        g.options.name,
157		providerOptions: g.options,
158		client:          client,
159	}, nil
160}
161
162func (a languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) {
163	config := &genai.GenerateContentConfig{}
164
165	providerOptions := &ProviderOptions{}
166	if v, ok := call.ProviderOptions[Name]; ok {
167		providerOptions, ok = v.(*ProviderOptions)
168		if !ok {
169			return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
170		}
171	}
172
173	systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
174
175	if providerOptions.ThinkingConfig != nil {
176		if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
177			*providerOptions.ThinkingConfig.IncludeThoughts &&
178			strings.HasPrefix(a.provider, "google.vertex.") {
179			warnings = append(warnings, fantasy.CallWarning{
180				Type: fantasy.CallWarningTypeOther,
181				Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
182					"and might not be supported or could behave unexpectedly with the current Google provider " +
183					fmt.Sprintf("(%s)", a.provider),
184			})
185		}
186
187		if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
188			*providerOptions.ThinkingConfig.ThinkingBudget < 128 {
189			warnings = append(warnings, fantasy.CallWarning{
190				Type:    fantasy.CallWarningTypeOther,
191				Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
192			})
193			providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128))
194		}
195	}
196
197	isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")
198
199	if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
200		if len(content) > 0 && content[0].Role == genai.RoleUser {
201			systemParts := []string{}
202			for _, sp := range systemInstructions.Parts {
203				systemParts = append(systemParts, sp.Text)
204			}
205			systemMsg := strings.Join(systemParts, "\n")
206			content[0].Parts = append([]*genai.Part{
207				{
208					Text: systemMsg + "\n\n",
209				},
210			}, content[0].Parts...)
211			systemInstructions = nil
212		}
213	}
214
215	config.SystemInstruction = systemInstructions
216
217	if call.MaxOutputTokens != nil {
218		config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
219	}
220
221	if call.Temperature != nil {
222		tmp := float32(*call.Temperature)
223		config.Temperature = &tmp
224	}
225	if call.TopK != nil {
226		tmp := float32(*call.TopK)
227		config.TopK = &tmp
228	}
229	if call.TopP != nil {
230		tmp := float32(*call.TopP)
231		config.TopP = &tmp
232	}
233	if call.FrequencyPenalty != nil {
234		tmp := float32(*call.FrequencyPenalty)
235		config.FrequencyPenalty = &tmp
236	}
237	if call.PresencePenalty != nil {
238		tmp := float32(*call.PresencePenalty)
239		config.PresencePenalty = &tmp
240	}
241
242	if providerOptions.ThinkingConfig != nil {
243		config.ThinkingConfig = &genai.ThinkingConfig{}
244		if providerOptions.ThinkingConfig.IncludeThoughts != nil {
245			config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
246		}
247		if providerOptions.ThinkingConfig.ThinkingBudget != nil {
248			tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
249			config.ThinkingConfig.ThinkingBudget = &tmp
250		}
251	}
252	for _, safetySetting := range providerOptions.SafetySettings {
253		config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
254			Category:  genai.HarmCategory(safetySetting.Category),
255			Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
256		})
257	}
258	if providerOptions.CachedContent != "" {
259		config.CachedContent = providerOptions.CachedContent
260	}
261
262	if len(call.Tools) > 0 {
263		tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
264		config.ToolConfig = toolChoice
265		config.Tools = append(config.Tools, &genai.Tool{
266			FunctionDeclarations: tools,
267		})
268		warnings = append(warnings, toolWarnings...)
269	}
270
271	return config, content, warnings, nil
272}
273
274func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam
275	var systemInstructions *genai.Content
276	var content []*genai.Content
277	var warnings []fantasy.CallWarning
278
279	finishedSystemBlock := false
280	for _, msg := range prompt {
281		switch msg.Role {
282		case fantasy.MessageRoleSystem:
283			if finishedSystemBlock {
284				// skip multiple system messages that are separated by user/assistant messages
285				// TODO: see if we need to send error here?
286				continue
287			}
288			finishedSystemBlock = true
289
290			var systemMessages []string
291			for _, part := range msg.Content {
292				text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
293				if !ok || text.Text == "" {
294					continue
295				}
296				systemMessages = append(systemMessages, text.Text)
297			}
298			if len(systemMessages) > 0 {
299				systemInstructions = &genai.Content{
300					Parts: []*genai.Part{
301						{
302							Text: strings.Join(systemMessages, "\n"),
303						},
304					},
305				}
306			}
307		case fantasy.MessageRoleUser:
308			var parts []*genai.Part
309			for _, part := range msg.Content {
310				switch part.GetType() {
311				case fantasy.ContentTypeText:
312					text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
313					if !ok || text.Text == "" {
314						continue
315					}
316					parts = append(parts, &genai.Part{
317						Text: text.Text,
318					})
319				case fantasy.ContentTypeFile:
320					file, ok := fantasy.AsMessagePart[fantasy.FilePart](part)
321					if !ok {
322						continue
323					}
324					var encoded []byte
325					base64.StdEncoding.Encode(encoded, file.Data)
326					parts = append(parts, &genai.Part{
327						InlineData: &genai.Blob{
328							Data:     encoded,
329							MIMEType: file.MediaType,
330						},
331					})
332				}
333			}
334			if len(parts) > 0 {
335				content = append(content, &genai.Content{
336					Role:  genai.RoleUser,
337					Parts: parts,
338				})
339			}
340		case fantasy.MessageRoleAssistant:
341			var parts []*genai.Part
342			for _, part := range msg.Content {
343				switch part.GetType() {
344				case fantasy.ContentTypeText:
345					text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
346					if !ok || text.Text == "" {
347						continue
348					}
349					parts = append(parts, &genai.Part{
350						Text: text.Text,
351					})
352				case fantasy.ContentTypeToolCall:
353					toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
354					if !ok {
355						continue
356					}
357
358					var result map[string]any
359					err := json.Unmarshal([]byte(toolCall.Input), &result)
360					if err != nil {
361						continue
362					}
363					parts = append(parts, &genai.Part{
364						FunctionCall: &genai.FunctionCall{
365							ID:   toolCall.ToolCallID,
366							Name: toolCall.ToolName,
367							Args: result,
368						},
369					})
370				}
371			}
372			if len(parts) > 0 {
373				content = append(content, &genai.Content{
374					Role:  genai.RoleModel,
375					Parts: parts,
376				})
377			}
378		case fantasy.MessageRoleTool:
379			var parts []*genai.Part
380			for _, part := range msg.Content {
381				switch part.GetType() {
382				case fantasy.ContentTypeToolResult:
383					result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
384					if !ok {
385						continue
386					}
387					var toolCall fantasy.ToolCallPart
388					for _, m := range prompt {
389						if m.Role == fantasy.MessageRoleAssistant {
390							for _, content := range m.Content {
391								tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content)
392								if !ok {
393									continue
394								}
395								if tc.ToolCallID == result.ToolCallID {
396									toolCall = tc
397									break
398								}
399							}
400						}
401					}
402					switch result.Output.GetType() {
403					case fantasy.ToolResultContentTypeText:
404						content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output)
405						if !ok {
406							continue
407						}
408						response := map[string]any{"result": content.Text}
409						parts = append(parts, &genai.Part{
410							FunctionResponse: &genai.FunctionResponse{
411								ID:       result.ToolCallID,
412								Response: response,
413								Name:     toolCall.ToolName,
414							},
415						})
416
417					case fantasy.ToolResultContentTypeError:
418						content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
419						if !ok {
420							continue
421						}
422						response := map[string]any{"result": content.Error.Error()}
423						parts = append(parts, &genai.Part{
424							FunctionResponse: &genai.FunctionResponse{
425								ID:       result.ToolCallID,
426								Response: response,
427								Name:     toolCall.ToolName,
428							},
429						})
430					}
431				}
432			}
433			if len(parts) > 0 {
434				content = append(content, &genai.Content{
435					Role:  genai.RoleUser,
436					Parts: parts,
437				})
438			}
439		default:
440			panic("unsupported message role: " + msg.Role)
441		}
442	}
443	return systemInstructions, content, warnings
444}
445
446// Generate implements fantasy.LanguageModel.
447func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
448	config, contents, warnings, err := g.prepareParams(call)
449	if err != nil {
450		return nil, err
451	}
452
453	lastMessage, history, ok := slice.Pop(contents)
454	if !ok {
455		return nil, errors.New("no messages to send")
456	}
457
458	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
459	if err != nil {
460		return nil, err
461	}
462
463	response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
464	if err != nil {
465		return nil, err
466	}
467
468	return mapResponse(response, warnings)
469}
470
471// Model implements fantasy.LanguageModel.
472func (g *languageModel) Model() string {
473	return g.modelID
474}
475
476// Provider implements fantasy.LanguageModel.
477func (g *languageModel) Provider() string {
478	return g.provider
479}
480
481// Stream implements fantasy.LanguageModel.
482func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
483	config, contents, warnings, err := g.prepareParams(call)
484	if err != nil {
485		return nil, err
486	}
487
488	lastMessage, history, ok := slice.Pop(contents)
489	if !ok {
490		return nil, errors.New("no messages to send")
491	}
492
493	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
494	if err != nil {
495		return nil, err
496	}
497
498	return func(yield func(fantasy.StreamPart) bool) {
499		if len(warnings) > 0 {
500			if !yield(fantasy.StreamPart{
501				Type:     fantasy.StreamPartTypeWarnings,
502				Warnings: warnings,
503			}) {
504				return
505			}
506		}
507
508		var currentContent string
509		var toolCalls []fantasy.ToolCallContent
510		var isActiveText bool
511		var isActiveReasoning bool
512		var blockCounter int
513		var currentTextBlockID string
514		var currentReasoningBlockID string
515		var usage fantasy.Usage
516		var lastFinishReason fantasy.FinishReason
517
518		for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
519			if err != nil {
520				yield(fantasy.StreamPart{
521					Type:  fantasy.StreamPartTypeError,
522					Error: err,
523				})
524				return
525			}
526
527			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
528				for _, part := range resp.Candidates[0].Content.Parts {
529					switch {
530					case part.Text != "":
531						delta := part.Text
532						if delta != "" {
533							// Check if this is a reasoning/thought part
534							if part.Thought {
535								// End any active text block before starting reasoning
536								if isActiveText {
537									isActiveText = false
538									if !yield(fantasy.StreamPart{
539										Type: fantasy.StreamPartTypeTextEnd,
540										ID:   currentTextBlockID,
541									}) {
542										return
543									}
544								}
545
546								// Start new reasoning block if not already active
547								if !isActiveReasoning {
548									isActiveReasoning = true
549									currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
550									blockCounter++
551									if !yield(fantasy.StreamPart{
552										Type: fantasy.StreamPartTypeReasoningStart,
553										ID:   currentReasoningBlockID,
554									}) {
555										return
556									}
557								}
558
559								if !yield(fantasy.StreamPart{
560									Type:  fantasy.StreamPartTypeReasoningDelta,
561									ID:    currentReasoningBlockID,
562									Delta: delta,
563								}) {
564									return
565								}
566							} else {
567								// Regular text part
568								// End any active reasoning block before starting text
569								if isActiveReasoning {
570									isActiveReasoning = false
571									if !yield(fantasy.StreamPart{
572										Type: fantasy.StreamPartTypeReasoningEnd,
573										ID:   currentReasoningBlockID,
574									}) {
575										return
576									}
577								}
578
579								// Start new text block if not already active
580								if !isActiveText {
581									isActiveText = true
582									currentTextBlockID = fmt.Sprintf("%d", blockCounter)
583									blockCounter++
584									if !yield(fantasy.StreamPart{
585										Type: fantasy.StreamPartTypeTextStart,
586										ID:   currentTextBlockID,
587									}) {
588										return
589									}
590								}
591
592								if !yield(fantasy.StreamPart{
593									Type:  fantasy.StreamPartTypeTextDelta,
594									ID:    currentTextBlockID,
595									Delta: delta,
596								}) {
597									return
598								}
599								currentContent += delta
600							}
601						}
602					case part.FunctionCall != nil:
603						// End any active text or reasoning blocks
604						if isActiveText {
605							isActiveText = false
606							if !yield(fantasy.StreamPart{
607								Type: fantasy.StreamPartTypeTextEnd,
608								ID:   currentTextBlockID,
609							}) {
610								return
611							}
612						}
613						if isActiveReasoning {
614							isActiveReasoning = false
615							if !yield(fantasy.StreamPart{
616								Type: fantasy.StreamPartTypeReasoningEnd,
617								ID:   currentReasoningBlockID,
618							}) {
619								return
620							}
621						}
622
623						toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
624
625						args, err := json.Marshal(part.FunctionCall.Args)
626						if err != nil {
627							yield(fantasy.StreamPart{
628								Type:  fantasy.StreamPartTypeError,
629								Error: err,
630							})
631							return
632						}
633
634						if !yield(fantasy.StreamPart{
635							Type:         fantasy.StreamPartTypeToolInputStart,
636							ID:           toolCallID,
637							ToolCallName: part.FunctionCall.Name,
638						}) {
639							return
640						}
641
642						if !yield(fantasy.StreamPart{
643							Type:  fantasy.StreamPartTypeToolInputDelta,
644							ID:    toolCallID,
645							Delta: string(args),
646						}) {
647							return
648						}
649
650						if !yield(fantasy.StreamPart{
651							Type: fantasy.StreamPartTypeToolInputEnd,
652							ID:   toolCallID,
653						}) {
654							return
655						}
656
657						if !yield(fantasy.StreamPart{
658							Type:             fantasy.StreamPartTypeToolCall,
659							ID:               toolCallID,
660							ToolCallName:     part.FunctionCall.Name,
661							ToolCallInput:    string(args),
662							ProviderExecuted: false,
663						}) {
664							return
665						}
666
667						toolCalls = append(toolCalls, fantasy.ToolCallContent{
668							ToolCallID:       toolCallID,
669							ToolName:         part.FunctionCall.Name,
670							Input:            string(args),
671							ProviderExecuted: false,
672						})
673					}
674				}
675			}
676
677			if resp.UsageMetadata != nil {
678				usage = mapUsage(resp.UsageMetadata)
679			}
680
681			if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
682				lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
683			}
684		}
685
686		// Close any open blocks before finishing
687		if isActiveText {
688			if !yield(fantasy.StreamPart{
689				Type: fantasy.StreamPartTypeTextEnd,
690				ID:   currentTextBlockID,
691			}) {
692				return
693			}
694		}
695		if isActiveReasoning {
696			if !yield(fantasy.StreamPart{
697				Type: fantasy.StreamPartTypeReasoningEnd,
698				ID:   currentReasoningBlockID,
699			}) {
700				return
701			}
702		}
703
704		finishReason := lastFinishReason
705		if len(toolCalls) > 0 {
706			finishReason = fantasy.FinishReasonToolCalls
707		} else if finishReason == "" {
708			finishReason = fantasy.FinishReasonStop
709		}
710
711		yield(fantasy.StreamPart{
712			Type:         fantasy.StreamPartTypeFinish,
713			Usage:        usage,
714			FinishReason: finishReason,
715		})
716	}, nil
717}
718
719func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) {
720	for _, tool := range tools {
721		if tool.GetType() == fantasy.ToolTypeFunction {
722			ft, ok := tool.(fantasy.FunctionTool)
723			if !ok {
724				continue
725			}
726
727			required := []string{}
728			var properties map[string]any
729			if props, ok := ft.InputSchema["properties"]; ok {
730				properties, _ = props.(map[string]any)
731			}
732			if req, ok := ft.InputSchema["required"]; ok {
733				if reqArr, ok := req.([]string); ok {
734					required = reqArr
735				}
736			}
737			declaration := &genai.FunctionDeclaration{
738				Name:        ft.Name,
739				Description: ft.Description,
740				Parameters: &genai.Schema{
741					Type:       genai.TypeObject,
742					Properties: convertSchemaProperties(properties),
743					Required:   required,
744				},
745			}
746			googleTools = append(googleTools, declaration)
747			continue
748		}
749		// TODO: handle provider tool calls
750		warnings = append(warnings, fantasy.CallWarning{
751			Type:    fantasy.CallWarningTypeUnsupportedTool,
752			Tool:    tool,
753			Message: "tool is not supported",
754		})
755	}
756	if toolChoice == nil {
757		return googleTools, googleToolChoice, warnings
758	}
759	switch *toolChoice {
760	case fantasy.ToolChoiceAuto:
761		googleToolChoice = &genai.ToolConfig{
762			FunctionCallingConfig: &genai.FunctionCallingConfig{
763				Mode: genai.FunctionCallingConfigModeAuto,
764			},
765		}
766	case fantasy.ToolChoiceRequired:
767		googleToolChoice = &genai.ToolConfig{
768			FunctionCallingConfig: &genai.FunctionCallingConfig{
769				Mode: genai.FunctionCallingConfigModeAny,
770			},
771		}
772	case fantasy.ToolChoiceNone:
773		googleToolChoice = &genai.ToolConfig{
774			FunctionCallingConfig: &genai.FunctionCallingConfig{
775				Mode: genai.FunctionCallingConfigModeNone,
776			},
777		}
778	default:
779		googleToolChoice = &genai.ToolConfig{
780			FunctionCallingConfig: &genai.FunctionCallingConfig{
781				Mode: genai.FunctionCallingConfigModeAny,
782				AllowedFunctionNames: []string{
783					string(*toolChoice),
784				},
785			},
786		}
787	}
788	return googleTools, googleToolChoice, warnings
789}
790
791func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
792	properties := make(map[string]*genai.Schema)
793
794	for name, param := range parameters {
795		properties[name] = convertToSchema(param)
796	}
797
798	return properties
799}
800
801func convertToSchema(param any) *genai.Schema {
802	schema := &genai.Schema{Type: genai.TypeString}
803
804	paramMap, ok := param.(map[string]any)
805	if !ok {
806		return schema
807	}
808
809	if desc, ok := paramMap["description"].(string); ok {
810		schema.Description = desc
811	}
812
813	typeVal, hasType := paramMap["type"]
814	if !hasType {
815		return schema
816	}
817
818	typeStr, ok := typeVal.(string)
819	if !ok {
820		return schema
821	}
822
823	schema.Type = mapJSONTypeToGoogle(typeStr)
824
825	switch typeStr {
826	case "array":
827		schema.Items = processArrayItems(paramMap)
828	case "object":
829		if props, ok := paramMap["properties"].(map[string]any); ok {
830			schema.Properties = convertSchemaProperties(props)
831		}
832	}
833
834	return schema
835}
836
837func processArrayItems(paramMap map[string]any) *genai.Schema {
838	items, ok := paramMap["items"].(map[string]any)
839	if !ok {
840		return nil
841	}
842
843	return convertToSchema(items)
844}
845
846func mapJSONTypeToGoogle(jsonType string) genai.Type {
847	switch jsonType {
848	case "string":
849		return genai.TypeString
850	case "number":
851		return genai.TypeNumber
852	case "integer":
853		return genai.TypeInteger
854	case "boolean":
855		return genai.TypeBoolean
856	case "array":
857		return genai.TypeArray
858	case "object":
859		return genai.TypeObject
860	default:
861		return genai.TypeString // Default to string for unknown types
862	}
863}
864
865func mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) {
866	if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
867		return nil, errors.New("no response from model")
868	}
869
870	var (
871		content      []fantasy.Content
872		finishReason fantasy.FinishReason
873		hasToolCalls bool
874		candidate    = response.Candidates[0]
875	)
876
877	for _, part := range candidate.Content.Parts {
878		switch {
879		case part.Text != "":
880			if part.Thought {
881				content = append(content, fantasy.ReasoningContent{Text: part.Text})
882			} else {
883				content = append(content, fantasy.TextContent{Text: part.Text})
884			}
885		case part.FunctionCall != nil:
886			input, err := json.Marshal(part.FunctionCall.Args)
887			if err != nil {
888				return nil, err
889			}
890			toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
891			content = append(content, fantasy.ToolCallContent{
892				ToolCallID:       toolCallID,
893				ToolName:         part.FunctionCall.Name,
894				Input:            string(input),
895				ProviderExecuted: false,
896			})
897			hasToolCalls = true
898		default:
899			// Silently skip unknown part types instead of erroring
900			// This allows for forward compatibility with new part types
901		}
902	}
903
904	if hasToolCalls {
905		finishReason = fantasy.FinishReasonToolCalls
906	} else {
907		finishReason = mapFinishReason(candidate.FinishReason)
908	}
909
910	return &fantasy.Response{
911		Content:      content,
912		Usage:        mapUsage(response.UsageMetadata),
913		FinishReason: finishReason,
914		Warnings:     warnings,
915	}, nil
916}
917
918func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason {
919	switch reason {
920	case genai.FinishReasonStop:
921		return fantasy.FinishReasonStop
922	case genai.FinishReasonMaxTokens:
923		return fantasy.FinishReasonLength
924	case genai.FinishReasonSafety,
925		genai.FinishReasonBlocklist,
926		genai.FinishReasonProhibitedContent,
927		genai.FinishReasonSPII,
928		genai.FinishReasonImageSafety:
929		return fantasy.FinishReasonContentFilter
930	case genai.FinishReasonRecitation,
931		genai.FinishReasonLanguage,
932		genai.FinishReasonMalformedFunctionCall:
933		return fantasy.FinishReasonError
934	case genai.FinishReasonOther:
935		return fantasy.FinishReasonOther
936	default:
937		return fantasy.FinishReasonUnknown
938	}
939}
940
941func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage {
942	return fantasy.Usage{
943		InputTokens:         int64(usage.ToolUsePromptTokenCount),
944		OutputTokens:        int64(usage.CandidatesTokenCount),
945		TotalTokens:         int64(usage.TotalTokenCount),
946		ReasoningTokens:     int64(usage.ThoughtsTokenCount),
947		CacheCreationTokens: int64(usage.CachedContentTokenCount),
948		CacheReadTokens:     0,
949	}
950}