1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "net/http"
11 "strings"
12
13 "charm.land/fantasy"
14 "charm.land/fantasy/providers/anthropic"
15 "cloud.google.com/go/auth"
16 "github.com/charmbracelet/go-genai"
17 "github.com/charmbracelet/x/exp/slice"
18 "github.com/google/uuid"
19)
20
// Name is the name of the Google provider. It is also the key under which
// Google-specific options are looked up in fantasy.Call.ProviderOptions, and
// the default provider name applied by New when WithName is not used.
const Name = "google"
23
// provider is the fantasy.Provider implementation returned by New.
type provider struct {
	// options is the configuration assembled from the Option functions.
	options options
}
27
// options holds the provider configuration assembled by New from Options.
type options struct {
	// apiKey authenticates against the Gemini API; set by WithGeminiAPIKey and
	// cleared by WithVertex.
	apiKey string
	// name is the provider name reported by language models; New defaults it
	// to Name.
	name string
	// baseURL, when non-empty, is passed to the genai client as its base URL.
	baseURL string
	// headers are sent with every genai request; New allocates the map so
	// WithHeaders can copy into it.
	headers map[string]string
	// client is handed to the genai and Anthropic-on-Vertex clients as their
	// HTTP client (nil presumably means the library default — confirm).
	client *http.Client
	// backend selects the Gemini API or Vertex AI; it keeps its zero value
	// unless WithGeminiAPIKey or WithVertex is applied.
	backend genai.Backend
	// project and location identify the Vertex AI deployment; set by WithVertex
	// and cleared by WithGeminiAPIKey.
	project  string
	location string
	// skipAuth makes LanguageModel install a dummy token provider instead of
	// real credentials.
	skipAuth bool
}
39
// Option defines a function that configures Google provider options.
// Options are passed to New and applied in order, so later ones win.
type Option = func(*options)
42
43// New creates a new Google provider with the given options.
44func New(opts ...Option) (fantasy.Provider, error) {
45 options := options{
46 headers: map[string]string{},
47 }
48 for _, o := range opts {
49 o(&options)
50 }
51
52 options.name = cmp.Or(options.name, Name)
53
54 return &provider{
55 options: options,
56 }, nil
57}
58
59// WithBaseURL sets the base URL for the Google provider.
60func WithBaseURL(baseURL string) Option {
61 return func(o *options) {
62 o.baseURL = baseURL
63 }
64}
65
66// WithGeminiAPIKey sets the Gemini API key for the Google provider.
67func WithGeminiAPIKey(apiKey string) Option {
68 return func(o *options) {
69 o.backend = genai.BackendGeminiAPI
70 o.apiKey = apiKey
71 o.project = ""
72 o.location = ""
73 }
74}
75
76// WithVertex configures the Google provider to use Vertex AI.
77func WithVertex(project, location string) Option {
78 if project == "" || location == "" {
79 panic("project and location must be provided")
80 }
81 return func(o *options) {
82 o.backend = genai.BackendVertexAI
83 o.apiKey = ""
84 o.project = project
85 o.location = location
86 }
87}
88
89// WithSkipAuth configures whether to skip authentication for the Google provider.
90func WithSkipAuth(skipAuth bool) Option {
91 return func(o *options) {
92 o.skipAuth = skipAuth
93 }
94}
95
96// WithName sets the name for the Google provider.
97func WithName(name string) Option {
98 return func(o *options) {
99 o.name = name
100 }
101}
102
103// WithHeaders sets the headers for the Google provider.
104func WithHeaders(headers map[string]string) Option {
105 return func(o *options) {
106 maps.Copy(o.headers, headers)
107 }
108}
109
110// WithHTTPClient sets the HTTP client for the Google provider.
111func WithHTTPClient(client *http.Client) Option {
112 return func(o *options) {
113 o.client = client
114 }
115}
116
// Name implements fantasy.Provider. It always returns the package constant
// Name, even when WithName configured a different name; that configured name
// is what the language models report from Provider().
func (*provider) Name() string {
	return Name
}
120
// languageModel is the genai-backed fantasy.LanguageModel created by
// provider.LanguageModel for non-Claude model IDs.
type languageModel struct {
	// provider is the configured provider name, returned by Provider().
	provider string
	// modelID is the model identifier passed to every chat request.
	modelID string
	// client is the genai client built from the provider options.
	client *genai.Client
	// providerOptions is the provider configuration, copied by value (the
	// headers map is still shared with the provider).
	providerOptions options
}
127
128// LanguageModel implements fantasy.Provider.
129func (a *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error) {
130 if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
131 p, err := anthropic.New(
132 anthropic.WithVertex(a.options.project, a.options.location),
133 anthropic.WithHTTPClient(a.options.client),
134 anthropic.WithSkipAuth(a.options.skipAuth),
135 )
136 if err != nil {
137 return nil, err
138 }
139 return p.LanguageModel(modelID)
140 }
141
142 cc := &genai.ClientConfig{
143 HTTPClient: a.options.client,
144 Backend: a.options.backend,
145 APIKey: a.options.apiKey,
146 Project: a.options.project,
147 Location: a.options.location,
148 }
149 if a.options.skipAuth {
150 cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
151 }
152
153 if a.options.baseURL != "" || len(a.options.headers) > 0 {
154 headers := http.Header{}
155 for k, v := range a.options.headers {
156 headers.Add(k, v)
157 }
158 cc.HTTPOptions = genai.HTTPOptions{
159 BaseURL: a.options.baseURL,
160 Headers: headers,
161 }
162 }
163 client, err := genai.NewClient(context.Background(), cc)
164 if err != nil {
165 return nil, err
166 }
167 return &languageModel{
168 modelID: modelID,
169 provider: a.options.name,
170 providerOptions: a.options,
171 client: client,
172 }, nil
173}
174
175func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) {
176 config := &genai.GenerateContentConfig{}
177
178 providerOptions := &ProviderOptions{}
179 if v, ok := call.ProviderOptions[Name]; ok {
180 providerOptions, ok = v.(*ProviderOptions)
181 if !ok {
182 return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
183 }
184 }
185
186 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
187
188 if providerOptions.ThinkingConfig != nil {
189 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
190 *providerOptions.ThinkingConfig.IncludeThoughts &&
191 strings.HasPrefix(g.provider, "google.vertex.") {
192 warnings = append(warnings, fantasy.CallWarning{
193 Type: fantasy.CallWarningTypeOther,
194 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
195 "and might not be supported or could behave unexpectedly with the current Google provider " +
196 fmt.Sprintf("(%s)", g.provider),
197 })
198 }
199
200 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
201 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
202 warnings = append(warnings, fantasy.CallWarning{
203 Type: fantasy.CallWarningTypeOther,
204 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
205 })
206 providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128))
207 }
208 }
209
210 isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")
211
212 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
213 if len(content) > 0 && content[0].Role == genai.RoleUser {
214 systemParts := []string{}
215 for _, sp := range systemInstructions.Parts {
216 systemParts = append(systemParts, sp.Text)
217 }
218 systemMsg := strings.Join(systemParts, "\n")
219 content[0].Parts = append([]*genai.Part{
220 {
221 Text: systemMsg + "\n\n",
222 },
223 }, content[0].Parts...)
224 systemInstructions = nil
225 }
226 }
227
228 config.SystemInstruction = systemInstructions
229
230 if call.MaxOutputTokens != nil {
231 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
232 }
233
234 if call.Temperature != nil {
235 tmp := float32(*call.Temperature)
236 config.Temperature = &tmp
237 }
238 if call.TopK != nil {
239 tmp := float32(*call.TopK)
240 config.TopK = &tmp
241 }
242 if call.TopP != nil {
243 tmp := float32(*call.TopP)
244 config.TopP = &tmp
245 }
246 if call.FrequencyPenalty != nil {
247 tmp := float32(*call.FrequencyPenalty)
248 config.FrequencyPenalty = &tmp
249 }
250 if call.PresencePenalty != nil {
251 tmp := float32(*call.PresencePenalty)
252 config.PresencePenalty = &tmp
253 }
254
255 if providerOptions.ThinkingConfig != nil {
256 config.ThinkingConfig = &genai.ThinkingConfig{}
257 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
258 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
259 }
260 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
261 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
262 config.ThinkingConfig.ThinkingBudget = &tmp
263 }
264 }
265 for _, safetySetting := range providerOptions.SafetySettings {
266 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
267 Category: genai.HarmCategory(safetySetting.Category),
268 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
269 })
270 }
271 if providerOptions.CachedContent != "" {
272 config.CachedContent = providerOptions.CachedContent
273 }
274
275 if len(call.Tools) > 0 {
276 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
277 config.ToolConfig = toolChoice
278 config.Tools = append(config.Tools, &genai.Tool{
279 FunctionDeclarations: tools,
280 })
281 warnings = append(warnings, toolWarnings...)
282 }
283
284 return config, content, warnings, nil
285}
286
287func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam
288 var systemInstructions *genai.Content
289 var content []*genai.Content
290 var warnings []fantasy.CallWarning
291
292 finishedSystemBlock := false
293 for _, msg := range prompt {
294 switch msg.Role {
295 case fantasy.MessageRoleSystem:
296 if finishedSystemBlock {
297 // skip multiple system messages that are separated by user/assistant messages
298 // TODO: see if we need to send error here?
299 continue
300 }
301 finishedSystemBlock = true
302
303 var systemMessages []string
304 for _, part := range msg.Content {
305 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
306 if !ok || text.Text == "" {
307 continue
308 }
309 systemMessages = append(systemMessages, text.Text)
310 }
311 if len(systemMessages) > 0 {
312 systemInstructions = &genai.Content{
313 Parts: []*genai.Part{
314 {
315 Text: strings.Join(systemMessages, "\n"),
316 },
317 },
318 }
319 }
320 case fantasy.MessageRoleUser:
321 var parts []*genai.Part
322 for _, part := range msg.Content {
323 switch part.GetType() {
324 case fantasy.ContentTypeText:
325 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
326 if !ok || text.Text == "" {
327 continue
328 }
329 parts = append(parts, &genai.Part{
330 Text: text.Text,
331 })
332 case fantasy.ContentTypeFile:
333 file, ok := fantasy.AsMessagePart[fantasy.FilePart](part)
334 if !ok {
335 continue
336 }
337 parts = append(parts, &genai.Part{
338 InlineData: &genai.Blob{
339 Data: file.Data,
340 MIMEType: file.MediaType,
341 },
342 })
343 }
344 }
345 if len(parts) > 0 {
346 content = append(content, &genai.Content{
347 Role: genai.RoleUser,
348 Parts: parts,
349 })
350 }
351 case fantasy.MessageRoleAssistant:
352 var parts []*genai.Part
353 for _, part := range msg.Content {
354 switch part.GetType() {
355 case fantasy.ContentTypeText:
356 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
357 if !ok || text.Text == "" {
358 continue
359 }
360 parts = append(parts, &genai.Part{
361 Text: text.Text,
362 })
363 case fantasy.ContentTypeToolCall:
364 toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
365 if !ok {
366 continue
367 }
368
369 var result map[string]any
370 err := json.Unmarshal([]byte(toolCall.Input), &result)
371 if err != nil {
372 continue
373 }
374 parts = append(parts, &genai.Part{
375 FunctionCall: &genai.FunctionCall{
376 ID: toolCall.ToolCallID,
377 Name: toolCall.ToolName,
378 Args: result,
379 },
380 })
381 }
382 }
383 if len(parts) > 0 {
384 content = append(content, &genai.Content{
385 Role: genai.RoleModel,
386 Parts: parts,
387 })
388 }
389 case fantasy.MessageRoleTool:
390 var parts []*genai.Part
391 for _, part := range msg.Content {
392 switch part.GetType() {
393 case fantasy.ContentTypeToolResult:
394 result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
395 if !ok {
396 continue
397 }
398 var toolCall fantasy.ToolCallPart
399 for _, m := range prompt {
400 if m.Role == fantasy.MessageRoleAssistant {
401 for _, content := range m.Content {
402 tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content)
403 if !ok {
404 continue
405 }
406 if tc.ToolCallID == result.ToolCallID {
407 toolCall = tc
408 break
409 }
410 }
411 }
412 }
413 switch result.Output.GetType() {
414 case fantasy.ToolResultContentTypeText:
415 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output)
416 if !ok {
417 continue
418 }
419 response := map[string]any{"result": content.Text}
420 parts = append(parts, &genai.Part{
421 FunctionResponse: &genai.FunctionResponse{
422 ID: result.ToolCallID,
423 Response: response,
424 Name: toolCall.ToolName,
425 },
426 })
427
428 case fantasy.ToolResultContentTypeError:
429 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
430 if !ok {
431 continue
432 }
433 response := map[string]any{"result": content.Error.Error()}
434 parts = append(parts, &genai.Part{
435 FunctionResponse: &genai.FunctionResponse{
436 ID: result.ToolCallID,
437 Response: response,
438 Name: toolCall.ToolName,
439 },
440 })
441 }
442 }
443 }
444 if len(parts) > 0 {
445 content = append(content, &genai.Content{
446 Role: genai.RoleUser,
447 Parts: parts,
448 })
449 }
450 default:
451 panic("unsupported message role: " + msg.Role)
452 }
453 }
454 return systemInstructions, content, warnings
455}
456
457// Generate implements fantasy.LanguageModel.
458func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
459 config, contents, warnings, err := g.prepareParams(call)
460 if err != nil {
461 return nil, err
462 }
463
464 lastMessage, history, ok := slice.Pop(contents)
465 if !ok {
466 return nil, errors.New("no messages to send")
467 }
468
469 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
470 if err != nil {
471 return nil, err
472 }
473
474 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
475 if err != nil {
476 return nil, err
477 }
478
479 return mapResponse(response, warnings)
480}
481
// Model implements fantasy.LanguageModel. It returns the model ID this
// language model was created with.
func (g *languageModel) Model() string {
	return g.modelID
}
486
// Provider implements fantasy.LanguageModel. It returns the provider name
// configured on the provider (Name unless overridden with WithName).
func (g *languageModel) Provider() string {
	return g.provider
}
491
492// Stream implements fantasy.LanguageModel.
493func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
494 config, contents, warnings, err := g.prepareParams(call)
495 if err != nil {
496 return nil, err
497 }
498
499 lastMessage, history, ok := slice.Pop(contents)
500 if !ok {
501 return nil, errors.New("no messages to send")
502 }
503
504 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
505 if err != nil {
506 return nil, err
507 }
508
509 return func(yield func(fantasy.StreamPart) bool) {
510 if len(warnings) > 0 {
511 if !yield(fantasy.StreamPart{
512 Type: fantasy.StreamPartTypeWarnings,
513 Warnings: warnings,
514 }) {
515 return
516 }
517 }
518
519 var currentContent string
520 var toolCalls []fantasy.ToolCallContent
521 var isActiveText bool
522 var isActiveReasoning bool
523 var blockCounter int
524 var currentTextBlockID string
525 var currentReasoningBlockID string
526 var usage fantasy.Usage
527 var lastFinishReason fantasy.FinishReason
528
529 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
530 if err != nil {
531 yield(fantasy.StreamPart{
532 Type: fantasy.StreamPartTypeError,
533 Error: err,
534 })
535 return
536 }
537
538 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
539 for _, part := range resp.Candidates[0].Content.Parts {
540 switch {
541 case part.Text != "":
542 delta := part.Text
543 if delta != "" {
544 // Check if this is a reasoning/thought part
545 if part.Thought {
546 // End any active text block before starting reasoning
547 if isActiveText {
548 isActiveText = false
549 if !yield(fantasy.StreamPart{
550 Type: fantasy.StreamPartTypeTextEnd,
551 ID: currentTextBlockID,
552 }) {
553 return
554 }
555 }
556
557 // Start new reasoning block if not already active
558 if !isActiveReasoning {
559 isActiveReasoning = true
560 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
561 blockCounter++
562 if !yield(fantasy.StreamPart{
563 Type: fantasy.StreamPartTypeReasoningStart,
564 ID: currentReasoningBlockID,
565 }) {
566 return
567 }
568 }
569
570 if !yield(fantasy.StreamPart{
571 Type: fantasy.StreamPartTypeReasoningDelta,
572 ID: currentReasoningBlockID,
573 Delta: delta,
574 }) {
575 return
576 }
577 } else {
578 // Regular text part
579 // End any active reasoning block before starting text
580 if isActiveReasoning {
581 isActiveReasoning = false
582 if !yield(fantasy.StreamPart{
583 Type: fantasy.StreamPartTypeReasoningEnd,
584 ID: currentReasoningBlockID,
585 }) {
586 return
587 }
588 }
589
590 // Start new text block if not already active
591 if !isActiveText {
592 isActiveText = true
593 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
594 blockCounter++
595 if !yield(fantasy.StreamPart{
596 Type: fantasy.StreamPartTypeTextStart,
597 ID: currentTextBlockID,
598 }) {
599 return
600 }
601 }
602
603 if !yield(fantasy.StreamPart{
604 Type: fantasy.StreamPartTypeTextDelta,
605 ID: currentTextBlockID,
606 Delta: delta,
607 }) {
608 return
609 }
610 currentContent += delta
611 }
612 }
613 case part.FunctionCall != nil:
614 // End any active text or reasoning blocks
615 if isActiveText {
616 isActiveText = false
617 if !yield(fantasy.StreamPart{
618 Type: fantasy.StreamPartTypeTextEnd,
619 ID: currentTextBlockID,
620 }) {
621 return
622 }
623 }
624 if isActiveReasoning {
625 isActiveReasoning = false
626 if !yield(fantasy.StreamPart{
627 Type: fantasy.StreamPartTypeReasoningEnd,
628 ID: currentReasoningBlockID,
629 }) {
630 return
631 }
632 }
633
634 toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
635
636 args, err := json.Marshal(part.FunctionCall.Args)
637 if err != nil {
638 yield(fantasy.StreamPart{
639 Type: fantasy.StreamPartTypeError,
640 Error: err,
641 })
642 return
643 }
644
645 if !yield(fantasy.StreamPart{
646 Type: fantasy.StreamPartTypeToolInputStart,
647 ID: toolCallID,
648 ToolCallName: part.FunctionCall.Name,
649 }) {
650 return
651 }
652
653 if !yield(fantasy.StreamPart{
654 Type: fantasy.StreamPartTypeToolInputDelta,
655 ID: toolCallID,
656 Delta: string(args),
657 }) {
658 return
659 }
660
661 if !yield(fantasy.StreamPart{
662 Type: fantasy.StreamPartTypeToolInputEnd,
663 ID: toolCallID,
664 }) {
665 return
666 }
667
668 if !yield(fantasy.StreamPart{
669 Type: fantasy.StreamPartTypeToolCall,
670 ID: toolCallID,
671 ToolCallName: part.FunctionCall.Name,
672 ToolCallInput: string(args),
673 ProviderExecuted: false,
674 }) {
675 return
676 }
677
678 toolCalls = append(toolCalls, fantasy.ToolCallContent{
679 ToolCallID: toolCallID,
680 ToolName: part.FunctionCall.Name,
681 Input: string(args),
682 ProviderExecuted: false,
683 })
684 }
685 }
686 }
687
688 if resp.UsageMetadata != nil {
689 usage = mapUsage(resp.UsageMetadata)
690 }
691
692 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
693 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
694 }
695 }
696
697 // Close any open blocks before finishing
698 if isActiveText {
699 if !yield(fantasy.StreamPart{
700 Type: fantasy.StreamPartTypeTextEnd,
701 ID: currentTextBlockID,
702 }) {
703 return
704 }
705 }
706 if isActiveReasoning {
707 if !yield(fantasy.StreamPart{
708 Type: fantasy.StreamPartTypeReasoningEnd,
709 ID: currentReasoningBlockID,
710 }) {
711 return
712 }
713 }
714
715 finishReason := lastFinishReason
716 if len(toolCalls) > 0 {
717 finishReason = fantasy.FinishReasonToolCalls
718 } else if finishReason == "" {
719 finishReason = fantasy.FinishReasonStop
720 }
721
722 yield(fantasy.StreamPart{
723 Type: fantasy.StreamPartTypeFinish,
724 Usage: usage,
725 FinishReason: finishReason,
726 })
727 }, nil
728}
729
730func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) {
731 for _, tool := range tools {
732 if tool.GetType() == fantasy.ToolTypeFunction {
733 ft, ok := tool.(fantasy.FunctionTool)
734 if !ok {
735 continue
736 }
737
738 required := []string{}
739 var properties map[string]any
740 if props, ok := ft.InputSchema["properties"]; ok {
741 properties, _ = props.(map[string]any)
742 }
743 if req, ok := ft.InputSchema["required"]; ok {
744 if reqArr, ok := req.([]string); ok {
745 required = reqArr
746 }
747 }
748 declaration := &genai.FunctionDeclaration{
749 Name: ft.Name,
750 Description: ft.Description,
751 Parameters: &genai.Schema{
752 Type: genai.TypeObject,
753 Properties: convertSchemaProperties(properties),
754 Required: required,
755 },
756 }
757 googleTools = append(googleTools, declaration)
758 continue
759 }
760 // TODO: handle provider tool calls
761 warnings = append(warnings, fantasy.CallWarning{
762 Type: fantasy.CallWarningTypeUnsupportedTool,
763 Tool: tool,
764 Message: "tool is not supported",
765 })
766 }
767 if toolChoice == nil {
768 return googleTools, googleToolChoice, warnings
769 }
770 switch *toolChoice {
771 case fantasy.ToolChoiceAuto:
772 googleToolChoice = &genai.ToolConfig{
773 FunctionCallingConfig: &genai.FunctionCallingConfig{
774 Mode: genai.FunctionCallingConfigModeAuto,
775 },
776 }
777 case fantasy.ToolChoiceRequired:
778 googleToolChoice = &genai.ToolConfig{
779 FunctionCallingConfig: &genai.FunctionCallingConfig{
780 Mode: genai.FunctionCallingConfigModeAny,
781 },
782 }
783 case fantasy.ToolChoiceNone:
784 googleToolChoice = &genai.ToolConfig{
785 FunctionCallingConfig: &genai.FunctionCallingConfig{
786 Mode: genai.FunctionCallingConfigModeNone,
787 },
788 }
789 default:
790 googleToolChoice = &genai.ToolConfig{
791 FunctionCallingConfig: &genai.FunctionCallingConfig{
792 Mode: genai.FunctionCallingConfigModeAny,
793 AllowedFunctionNames: []string{
794 string(*toolChoice),
795 },
796 },
797 }
798 }
799 return googleTools, googleToolChoice, warnings
800}
801
802func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
803 properties := make(map[string]*genai.Schema)
804
805 for name, param := range parameters {
806 properties[name] = convertToSchema(param)
807 }
808
809 return properties
810}
811
812func convertToSchema(param any) *genai.Schema {
813 schema := &genai.Schema{Type: genai.TypeString}
814
815 paramMap, ok := param.(map[string]any)
816 if !ok {
817 return schema
818 }
819
820 if desc, ok := paramMap["description"].(string); ok {
821 schema.Description = desc
822 }
823
824 typeVal, hasType := paramMap["type"]
825 if !hasType {
826 return schema
827 }
828
829 typeStr, ok := typeVal.(string)
830 if !ok {
831 return schema
832 }
833
834 schema.Type = mapJSONTypeToGoogle(typeStr)
835
836 switch typeStr {
837 case "array":
838 schema.Items = processArrayItems(paramMap)
839 case "object":
840 if props, ok := paramMap["properties"].(map[string]any); ok {
841 schema.Properties = convertSchemaProperties(props)
842 }
843 }
844
845 return schema
846}
847
848func processArrayItems(paramMap map[string]any) *genai.Schema {
849 items, ok := paramMap["items"].(map[string]any)
850 if !ok {
851 return nil
852 }
853
854 return convertToSchema(items)
855}
856
857func mapJSONTypeToGoogle(jsonType string) genai.Type {
858 switch jsonType {
859 case "string":
860 return genai.TypeString
861 case "number":
862 return genai.TypeNumber
863 case "integer":
864 return genai.TypeInteger
865 case "boolean":
866 return genai.TypeBoolean
867 case "array":
868 return genai.TypeArray
869 case "object":
870 return genai.TypeObject
871 default:
872 return genai.TypeString // Default to string for unknown types
873 }
874}
875
876func mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) {
877 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
878 return nil, errors.New("no response from model")
879 }
880
881 var (
882 content []fantasy.Content
883 finishReason fantasy.FinishReason
884 hasToolCalls bool
885 candidate = response.Candidates[0]
886 )
887
888 for _, part := range candidate.Content.Parts {
889 switch {
890 case part.Text != "":
891 if part.Thought {
892 content = append(content, fantasy.ReasoningContent{Text: part.Text})
893 } else {
894 content = append(content, fantasy.TextContent{Text: part.Text})
895 }
896 case part.FunctionCall != nil:
897 input, err := json.Marshal(part.FunctionCall.Args)
898 if err != nil {
899 return nil, err
900 }
901 toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
902 content = append(content, fantasy.ToolCallContent{
903 ToolCallID: toolCallID,
904 ToolName: part.FunctionCall.Name,
905 Input: string(input),
906 ProviderExecuted: false,
907 })
908 hasToolCalls = true
909 default:
910 // Silently skip unknown part types instead of erroring
911 // This allows for forward compatibility with new part types
912 }
913 }
914
915 if hasToolCalls {
916 finishReason = fantasy.FinishReasonToolCalls
917 } else {
918 finishReason = mapFinishReason(candidate.FinishReason)
919 }
920
921 return &fantasy.Response{
922 Content: content,
923 Usage: mapUsage(response.UsageMetadata),
924 FinishReason: finishReason,
925 Warnings: warnings,
926 }, nil
927}
928
929func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason {
930 switch reason {
931 case genai.FinishReasonStop:
932 return fantasy.FinishReasonStop
933 case genai.FinishReasonMaxTokens:
934 return fantasy.FinishReasonLength
935 case genai.FinishReasonSafety,
936 genai.FinishReasonBlocklist,
937 genai.FinishReasonProhibitedContent,
938 genai.FinishReasonSPII,
939 genai.FinishReasonImageSafety:
940 return fantasy.FinishReasonContentFilter
941 case genai.FinishReasonRecitation,
942 genai.FinishReasonLanguage,
943 genai.FinishReasonMalformedFunctionCall:
944 return fantasy.FinishReasonError
945 case genai.FinishReasonOther:
946 return fantasy.FinishReasonOther
947 default:
948 return fantasy.FinishReasonUnknown
949 }
950}
951
952func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage {
953 return fantasy.Usage{
954 InputTokens: int64(usage.ToolUsePromptTokenCount),
955 OutputTokens: int64(usage.CandidatesTokenCount),
956 TotalTokens: int64(usage.TotalTokenCount),
957 ReasoningTokens: int64(usage.ThoughtsTokenCount),
958 CacheCreationTokens: int64(usage.CachedContentTokenCount),
959 CacheReadTokens: 0,
960 }
961}