1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "net/http"
11 "strings"
12
13 "charm.land/fantasy"
14 "charm.land/fantasy/providers/anthropic"
15 "cloud.google.com/go/auth"
16 "github.com/charmbracelet/x/exp/slice"
17 "github.com/google/uuid"
18 "google.golang.org/genai"
19)
20
// Name is the name of the Google provider. It is also the key under which
// Google-specific *ProviderOptions are looked up in fantasy.Call.ProviderOptions,
// and the value reported by the provider's Name method.
const Name = "google"
23
// provider is the fantasy.Provider implementation returned by New. It only
// holds the options resolved at construction time; language-model clients are
// created on demand by LanguageModel.
type provider struct {
	options options
}
27
// options holds the provider configuration assembled by New from the supplied
// Option functions.
type options struct {
	// apiKey is the Gemini API key (set by WithGeminiAPIKey, cleared by WithVertex).
	apiKey   string
	// name is the provider name reported by language models; New defaults it to Name.
	name     string
	// baseURL, when non-empty, overrides the genai client's base URL.
	baseURL  string
	// headers are extra HTTP headers passed to the genai client (see WithHeaders).
	headers  map[string]string
	// client is the HTTP client handed to the genai client and the Anthropic delegate.
	client   *http.Client
	// backend selects the genai backend (Gemini API or Vertex AI).
	backend  genai.Backend
	// project and location identify the Vertex AI deployment (set by WithVertex).
	project  string
	location string
	// skipAuth, when true, makes LanguageModel install credentials backed by
	// dummyTokenProvider and forward the flag to the Anthropic delegate.
	skipAuth bool
}
39
// Option defines a function that configures Google provider options.
// It is a type alias, so any func(*options) can be passed to New. New applies
// options in order, so later options override earlier ones.
type Option = func(*options)
42
43// New creates a new Google provider with the given options.
44func New(opts ...Option) fantasy.Provider {
45 options := options{
46 headers: map[string]string{},
47 }
48 for _, o := range opts {
49 o(&options)
50 }
51
52 options.name = cmp.Or(options.name, Name)
53
54 return &provider{
55 options: options,
56 }
57}
58
59// WithBaseURL sets the base URL for the Google provider.
60func WithBaseURL(baseURL string) Option {
61 return func(o *options) {
62 o.baseURL = baseURL
63 }
64}
65
66// WithGeminiAPIKey sets the Gemini API key for the Google provider.
67func WithGeminiAPIKey(apiKey string) Option {
68 return func(o *options) {
69 o.backend = genai.BackendGeminiAPI
70 o.apiKey = apiKey
71 o.project = ""
72 o.location = ""
73 }
74}
75
76// WithVertex configures the Google provider to use Vertex AI.
77func WithVertex(project, location string) Option {
78 if project == "" || location == "" {
79 panic("project and location must be provided")
80 }
81 return func(o *options) {
82 o.backend = genai.BackendVertexAI
83 o.apiKey = ""
84 o.project = project
85 o.location = location
86 }
87}
88
89// WithSkipAuth configures whether to skip authentication for the Google provider.
90func WithSkipAuth(skipAuth bool) Option {
91 return func(o *options) {
92 o.skipAuth = skipAuth
93 }
94}
95
96// WithName sets the name for the Google provider.
97func WithName(name string) Option {
98 return func(o *options) {
99 o.name = name
100 }
101}
102
103// WithHeaders sets the headers for the Google provider.
104func WithHeaders(headers map[string]string) Option {
105 return func(o *options) {
106 maps.Copy(o.headers, headers)
107 }
108}
109
110// WithHTTPClient sets the HTTP client for the Google provider.
111func WithHTTPClient(client *http.Client) Option {
112 return func(o *options) {
113 o.client = client
114 }
115}
116
117func (*provider) Name() string {
118 return Name
119}
120
// languageModel is the genai-backed fantasy.LanguageModel created by
// provider.LanguageModel for non-Anthropic model IDs.
type languageModel struct {
	// provider is the provider name returned by Provider (options.name).
	provider        string
	// modelID is the model identifier sent with every request.
	modelID         string
	// client is the genai client used to open chat sessions.
	client          *genai.Client
	// providerOptions is a copy of the provider's options at creation time.
	providerOptions options
}
127
128// LanguageModel implements fantasy.Provider.
129func (a *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error) {
130 if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
131 return anthropic.New(
132 anthropic.WithVertex(a.options.project, a.options.location),
133 anthropic.WithHTTPClient(a.options.client),
134 anthropic.WithSkipAuth(a.options.skipAuth),
135 ).LanguageModel(modelID)
136 }
137
138 cc := &genai.ClientConfig{
139 HTTPClient: a.options.client,
140 Backend: a.options.backend,
141 APIKey: a.options.apiKey,
142 Project: a.options.project,
143 Location: a.options.location,
144 }
145 if a.options.skipAuth {
146 cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
147 }
148
149 if a.options.baseURL != "" || len(a.options.headers) > 0 {
150 headers := http.Header{}
151 for k, v := range a.options.headers {
152 headers.Add(k, v)
153 }
154 cc.HTTPOptions = genai.HTTPOptions{
155 BaseURL: a.options.baseURL,
156 Headers: headers,
157 }
158 }
159 client, err := genai.NewClient(context.Background(), cc)
160 if err != nil {
161 return nil, err
162 }
163 return &languageModel{
164 modelID: modelID,
165 provider: a.options.name,
166 providerOptions: a.options,
167 client: client,
168 }, nil
169}
170
171func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) {
172 config := &genai.GenerateContentConfig{}
173
174 providerOptions := &ProviderOptions{}
175 if v, ok := call.ProviderOptions[Name]; ok {
176 providerOptions, ok = v.(*ProviderOptions)
177 if !ok {
178 return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
179 }
180 }
181
182 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
183
184 if providerOptions.ThinkingConfig != nil {
185 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
186 *providerOptions.ThinkingConfig.IncludeThoughts &&
187 strings.HasPrefix(g.provider, "google.vertex.") {
188 warnings = append(warnings, fantasy.CallWarning{
189 Type: fantasy.CallWarningTypeOther,
190 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
191 "and might not be supported or could behave unexpectedly with the current Google provider " +
192 fmt.Sprintf("(%s)", g.provider),
193 })
194 }
195
196 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
197 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
198 warnings = append(warnings, fantasy.CallWarning{
199 Type: fantasy.CallWarningTypeOther,
200 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
201 })
202 providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128))
203 }
204 }
205
206 isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")
207
208 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
209 if len(content) > 0 && content[0].Role == genai.RoleUser {
210 systemParts := []string{}
211 for _, sp := range systemInstructions.Parts {
212 systemParts = append(systemParts, sp.Text)
213 }
214 systemMsg := strings.Join(systemParts, "\n")
215 content[0].Parts = append([]*genai.Part{
216 {
217 Text: systemMsg + "\n\n",
218 },
219 }, content[0].Parts...)
220 systemInstructions = nil
221 }
222 }
223
224 config.SystemInstruction = systemInstructions
225
226 if call.MaxOutputTokens != nil {
227 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
228 }
229
230 if call.Temperature != nil {
231 tmp := float32(*call.Temperature)
232 config.Temperature = &tmp
233 }
234 if call.TopK != nil {
235 tmp := float32(*call.TopK)
236 config.TopK = &tmp
237 }
238 if call.TopP != nil {
239 tmp := float32(*call.TopP)
240 config.TopP = &tmp
241 }
242 if call.FrequencyPenalty != nil {
243 tmp := float32(*call.FrequencyPenalty)
244 config.FrequencyPenalty = &tmp
245 }
246 if call.PresencePenalty != nil {
247 tmp := float32(*call.PresencePenalty)
248 config.PresencePenalty = &tmp
249 }
250
251 if providerOptions.ThinkingConfig != nil {
252 config.ThinkingConfig = &genai.ThinkingConfig{}
253 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
254 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
255 }
256 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
257 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
258 config.ThinkingConfig.ThinkingBudget = &tmp
259 }
260 }
261 for _, safetySetting := range providerOptions.SafetySettings {
262 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
263 Category: genai.HarmCategory(safetySetting.Category),
264 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
265 })
266 }
267 if providerOptions.CachedContent != "" {
268 config.CachedContent = providerOptions.CachedContent
269 }
270
271 if len(call.Tools) > 0 {
272 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
273 config.ToolConfig = toolChoice
274 config.Tools = append(config.Tools, &genai.Tool{
275 FunctionDeclarations: tools,
276 })
277 warnings = append(warnings, toolWarnings...)
278 }
279
280 return config, content, warnings, nil
281}
282
283func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam
284 var systemInstructions *genai.Content
285 var content []*genai.Content
286 var warnings []fantasy.CallWarning
287
288 finishedSystemBlock := false
289 for _, msg := range prompt {
290 switch msg.Role {
291 case fantasy.MessageRoleSystem:
292 if finishedSystemBlock {
293 // skip multiple system messages that are separated by user/assistant messages
294 // TODO: see if we need to send error here?
295 continue
296 }
297 finishedSystemBlock = true
298
299 var systemMessages []string
300 for _, part := range msg.Content {
301 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
302 if !ok || text.Text == "" {
303 continue
304 }
305 systemMessages = append(systemMessages, text.Text)
306 }
307 if len(systemMessages) > 0 {
308 systemInstructions = &genai.Content{
309 Parts: []*genai.Part{
310 {
311 Text: strings.Join(systemMessages, "\n"),
312 },
313 },
314 }
315 }
316 case fantasy.MessageRoleUser:
317 var parts []*genai.Part
318 for _, part := range msg.Content {
319 switch part.GetType() {
320 case fantasy.ContentTypeText:
321 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
322 if !ok || text.Text == "" {
323 continue
324 }
325 parts = append(parts, &genai.Part{
326 Text: text.Text,
327 })
328 case fantasy.ContentTypeFile:
329 file, ok := fantasy.AsMessagePart[fantasy.FilePart](part)
330 if !ok {
331 continue
332 }
333 parts = append(parts, &genai.Part{
334 InlineData: &genai.Blob{
335 Data: file.Data,
336 MIMEType: file.MediaType,
337 },
338 })
339 }
340 }
341 if len(parts) > 0 {
342 content = append(content, &genai.Content{
343 Role: genai.RoleUser,
344 Parts: parts,
345 })
346 }
347 case fantasy.MessageRoleAssistant:
348 var parts []*genai.Part
349 for _, part := range msg.Content {
350 switch part.GetType() {
351 case fantasy.ContentTypeText:
352 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
353 if !ok || text.Text == "" {
354 continue
355 }
356 parts = append(parts, &genai.Part{
357 Text: text.Text,
358 })
359 case fantasy.ContentTypeToolCall:
360 toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
361 if !ok {
362 continue
363 }
364
365 var result map[string]any
366 err := json.Unmarshal([]byte(toolCall.Input), &result)
367 if err != nil {
368 continue
369 }
370 parts = append(parts, &genai.Part{
371 FunctionCall: &genai.FunctionCall{
372 ID: toolCall.ToolCallID,
373 Name: toolCall.ToolName,
374 Args: result,
375 },
376 })
377 }
378 }
379 if len(parts) > 0 {
380 content = append(content, &genai.Content{
381 Role: genai.RoleModel,
382 Parts: parts,
383 })
384 }
385 case fantasy.MessageRoleTool:
386 var parts []*genai.Part
387 for _, part := range msg.Content {
388 switch part.GetType() {
389 case fantasy.ContentTypeToolResult:
390 result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
391 if !ok {
392 continue
393 }
394 var toolCall fantasy.ToolCallPart
395 for _, m := range prompt {
396 if m.Role == fantasy.MessageRoleAssistant {
397 for _, content := range m.Content {
398 tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content)
399 if !ok {
400 continue
401 }
402 if tc.ToolCallID == result.ToolCallID {
403 toolCall = tc
404 break
405 }
406 }
407 }
408 }
409 switch result.Output.GetType() {
410 case fantasy.ToolResultContentTypeText:
411 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output)
412 if !ok {
413 continue
414 }
415 response := map[string]any{"result": content.Text}
416 parts = append(parts, &genai.Part{
417 FunctionResponse: &genai.FunctionResponse{
418 ID: result.ToolCallID,
419 Response: response,
420 Name: toolCall.ToolName,
421 },
422 })
423
424 case fantasy.ToolResultContentTypeError:
425 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
426 if !ok {
427 continue
428 }
429 response := map[string]any{"result": content.Error.Error()}
430 parts = append(parts, &genai.Part{
431 FunctionResponse: &genai.FunctionResponse{
432 ID: result.ToolCallID,
433 Response: response,
434 Name: toolCall.ToolName,
435 },
436 })
437 }
438 }
439 }
440 if len(parts) > 0 {
441 content = append(content, &genai.Content{
442 Role: genai.RoleUser,
443 Parts: parts,
444 })
445 }
446 default:
447 panic("unsupported message role: " + msg.Role)
448 }
449 }
450 return systemInstructions, content, warnings
451}
452
453// Generate implements fantasy.LanguageModel.
454func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
455 config, contents, warnings, err := g.prepareParams(call)
456 if err != nil {
457 return nil, err
458 }
459
460 lastMessage, history, ok := slice.Pop(contents)
461 if !ok {
462 return nil, errors.New("no messages to send")
463 }
464
465 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
466 if err != nil {
467 return nil, err
468 }
469
470 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
471 if err != nil {
472 return nil, err
473 }
474
475 return mapResponse(response, warnings)
476}
477
478// Model implements fantasy.LanguageModel.
479func (g *languageModel) Model() string {
480 return g.modelID
481}
482
483// Provider implements fantasy.LanguageModel.
484func (g *languageModel) Provider() string {
485 return g.provider
486}
487
488// Stream implements fantasy.LanguageModel.
489func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
490 config, contents, warnings, err := g.prepareParams(call)
491 if err != nil {
492 return nil, err
493 }
494
495 lastMessage, history, ok := slice.Pop(contents)
496 if !ok {
497 return nil, errors.New("no messages to send")
498 }
499
500 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
501 if err != nil {
502 return nil, err
503 }
504
505 return func(yield func(fantasy.StreamPart) bool) {
506 if len(warnings) > 0 {
507 if !yield(fantasy.StreamPart{
508 Type: fantasy.StreamPartTypeWarnings,
509 Warnings: warnings,
510 }) {
511 return
512 }
513 }
514
515 var currentContent string
516 var toolCalls []fantasy.ToolCallContent
517 var isActiveText bool
518 var isActiveReasoning bool
519 var blockCounter int
520 var currentTextBlockID string
521 var currentReasoningBlockID string
522 var usage fantasy.Usage
523 var lastFinishReason fantasy.FinishReason
524
525 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
526 if err != nil {
527 yield(fantasy.StreamPart{
528 Type: fantasy.StreamPartTypeError,
529 Error: err,
530 })
531 return
532 }
533
534 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
535 for _, part := range resp.Candidates[0].Content.Parts {
536 switch {
537 case part.Text != "":
538 delta := part.Text
539 if delta != "" {
540 // Check if this is a reasoning/thought part
541 if part.Thought {
542 // End any active text block before starting reasoning
543 if isActiveText {
544 isActiveText = false
545 if !yield(fantasy.StreamPart{
546 Type: fantasy.StreamPartTypeTextEnd,
547 ID: currentTextBlockID,
548 }) {
549 return
550 }
551 }
552
553 // Start new reasoning block if not already active
554 if !isActiveReasoning {
555 isActiveReasoning = true
556 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
557 blockCounter++
558 if !yield(fantasy.StreamPart{
559 Type: fantasy.StreamPartTypeReasoningStart,
560 ID: currentReasoningBlockID,
561 }) {
562 return
563 }
564 }
565
566 if !yield(fantasy.StreamPart{
567 Type: fantasy.StreamPartTypeReasoningDelta,
568 ID: currentReasoningBlockID,
569 Delta: delta,
570 }) {
571 return
572 }
573 } else {
574 // Regular text part
575 // End any active reasoning block before starting text
576 if isActiveReasoning {
577 isActiveReasoning = false
578 if !yield(fantasy.StreamPart{
579 Type: fantasy.StreamPartTypeReasoningEnd,
580 ID: currentReasoningBlockID,
581 }) {
582 return
583 }
584 }
585
586 // Start new text block if not already active
587 if !isActiveText {
588 isActiveText = true
589 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
590 blockCounter++
591 if !yield(fantasy.StreamPart{
592 Type: fantasy.StreamPartTypeTextStart,
593 ID: currentTextBlockID,
594 }) {
595 return
596 }
597 }
598
599 if !yield(fantasy.StreamPart{
600 Type: fantasy.StreamPartTypeTextDelta,
601 ID: currentTextBlockID,
602 Delta: delta,
603 }) {
604 return
605 }
606 currentContent += delta
607 }
608 }
609 case part.FunctionCall != nil:
610 // End any active text or reasoning blocks
611 if isActiveText {
612 isActiveText = false
613 if !yield(fantasy.StreamPart{
614 Type: fantasy.StreamPartTypeTextEnd,
615 ID: currentTextBlockID,
616 }) {
617 return
618 }
619 }
620 if isActiveReasoning {
621 isActiveReasoning = false
622 if !yield(fantasy.StreamPart{
623 Type: fantasy.StreamPartTypeReasoningEnd,
624 ID: currentReasoningBlockID,
625 }) {
626 return
627 }
628 }
629
630 toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
631
632 args, err := json.Marshal(part.FunctionCall.Args)
633 if err != nil {
634 yield(fantasy.StreamPart{
635 Type: fantasy.StreamPartTypeError,
636 Error: err,
637 })
638 return
639 }
640
641 if !yield(fantasy.StreamPart{
642 Type: fantasy.StreamPartTypeToolInputStart,
643 ID: toolCallID,
644 ToolCallName: part.FunctionCall.Name,
645 }) {
646 return
647 }
648
649 if !yield(fantasy.StreamPart{
650 Type: fantasy.StreamPartTypeToolInputDelta,
651 ID: toolCallID,
652 Delta: string(args),
653 }) {
654 return
655 }
656
657 if !yield(fantasy.StreamPart{
658 Type: fantasy.StreamPartTypeToolInputEnd,
659 ID: toolCallID,
660 }) {
661 return
662 }
663
664 if !yield(fantasy.StreamPart{
665 Type: fantasy.StreamPartTypeToolCall,
666 ID: toolCallID,
667 ToolCallName: part.FunctionCall.Name,
668 ToolCallInput: string(args),
669 ProviderExecuted: false,
670 }) {
671 return
672 }
673
674 toolCalls = append(toolCalls, fantasy.ToolCallContent{
675 ToolCallID: toolCallID,
676 ToolName: part.FunctionCall.Name,
677 Input: string(args),
678 ProviderExecuted: false,
679 })
680 }
681 }
682 }
683
684 if resp.UsageMetadata != nil {
685 usage = mapUsage(resp.UsageMetadata)
686 }
687
688 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
689 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
690 }
691 }
692
693 // Close any open blocks before finishing
694 if isActiveText {
695 if !yield(fantasy.StreamPart{
696 Type: fantasy.StreamPartTypeTextEnd,
697 ID: currentTextBlockID,
698 }) {
699 return
700 }
701 }
702 if isActiveReasoning {
703 if !yield(fantasy.StreamPart{
704 Type: fantasy.StreamPartTypeReasoningEnd,
705 ID: currentReasoningBlockID,
706 }) {
707 return
708 }
709 }
710
711 finishReason := lastFinishReason
712 if len(toolCalls) > 0 {
713 finishReason = fantasy.FinishReasonToolCalls
714 } else if finishReason == "" {
715 finishReason = fantasy.FinishReasonStop
716 }
717
718 yield(fantasy.StreamPart{
719 Type: fantasy.StreamPartTypeFinish,
720 Usage: usage,
721 FinishReason: finishReason,
722 })
723 }, nil
724}
725
726func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) {
727 for _, tool := range tools {
728 if tool.GetType() == fantasy.ToolTypeFunction {
729 ft, ok := tool.(fantasy.FunctionTool)
730 if !ok {
731 continue
732 }
733
734 required := []string{}
735 var properties map[string]any
736 if props, ok := ft.InputSchema["properties"]; ok {
737 properties, _ = props.(map[string]any)
738 }
739 if req, ok := ft.InputSchema["required"]; ok {
740 if reqArr, ok := req.([]string); ok {
741 required = reqArr
742 }
743 }
744 declaration := &genai.FunctionDeclaration{
745 Name: ft.Name,
746 Description: ft.Description,
747 Parameters: &genai.Schema{
748 Type: genai.TypeObject,
749 Properties: convertSchemaProperties(properties),
750 Required: required,
751 },
752 }
753 googleTools = append(googleTools, declaration)
754 continue
755 }
756 // TODO: handle provider tool calls
757 warnings = append(warnings, fantasy.CallWarning{
758 Type: fantasy.CallWarningTypeUnsupportedTool,
759 Tool: tool,
760 Message: "tool is not supported",
761 })
762 }
763 if toolChoice == nil {
764 return googleTools, googleToolChoice, warnings
765 }
766 switch *toolChoice {
767 case fantasy.ToolChoiceAuto:
768 googleToolChoice = &genai.ToolConfig{
769 FunctionCallingConfig: &genai.FunctionCallingConfig{
770 Mode: genai.FunctionCallingConfigModeAuto,
771 },
772 }
773 case fantasy.ToolChoiceRequired:
774 googleToolChoice = &genai.ToolConfig{
775 FunctionCallingConfig: &genai.FunctionCallingConfig{
776 Mode: genai.FunctionCallingConfigModeAny,
777 },
778 }
779 case fantasy.ToolChoiceNone:
780 googleToolChoice = &genai.ToolConfig{
781 FunctionCallingConfig: &genai.FunctionCallingConfig{
782 Mode: genai.FunctionCallingConfigModeNone,
783 },
784 }
785 default:
786 googleToolChoice = &genai.ToolConfig{
787 FunctionCallingConfig: &genai.FunctionCallingConfig{
788 Mode: genai.FunctionCallingConfigModeAny,
789 AllowedFunctionNames: []string{
790 string(*toolChoice),
791 },
792 },
793 }
794 }
795 return googleTools, googleToolChoice, warnings
796}
797
798func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
799 properties := make(map[string]*genai.Schema)
800
801 for name, param := range parameters {
802 properties[name] = convertToSchema(param)
803 }
804
805 return properties
806}
807
808func convertToSchema(param any) *genai.Schema {
809 schema := &genai.Schema{Type: genai.TypeString}
810
811 paramMap, ok := param.(map[string]any)
812 if !ok {
813 return schema
814 }
815
816 if desc, ok := paramMap["description"].(string); ok {
817 schema.Description = desc
818 }
819
820 typeVal, hasType := paramMap["type"]
821 if !hasType {
822 return schema
823 }
824
825 typeStr, ok := typeVal.(string)
826 if !ok {
827 return schema
828 }
829
830 schema.Type = mapJSONTypeToGoogle(typeStr)
831
832 switch typeStr {
833 case "array":
834 schema.Items = processArrayItems(paramMap)
835 case "object":
836 if props, ok := paramMap["properties"].(map[string]any); ok {
837 schema.Properties = convertSchemaProperties(props)
838 }
839 }
840
841 return schema
842}
843
844func processArrayItems(paramMap map[string]any) *genai.Schema {
845 items, ok := paramMap["items"].(map[string]any)
846 if !ok {
847 return nil
848 }
849
850 return convertToSchema(items)
851}
852
853func mapJSONTypeToGoogle(jsonType string) genai.Type {
854 switch jsonType {
855 case "string":
856 return genai.TypeString
857 case "number":
858 return genai.TypeNumber
859 case "integer":
860 return genai.TypeInteger
861 case "boolean":
862 return genai.TypeBoolean
863 case "array":
864 return genai.TypeArray
865 case "object":
866 return genai.TypeObject
867 default:
868 return genai.TypeString // Default to string for unknown types
869 }
870}
871
872func mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) {
873 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
874 return nil, errors.New("no response from model")
875 }
876
877 var (
878 content []fantasy.Content
879 finishReason fantasy.FinishReason
880 hasToolCalls bool
881 candidate = response.Candidates[0]
882 )
883
884 for _, part := range candidate.Content.Parts {
885 switch {
886 case part.Text != "":
887 if part.Thought {
888 content = append(content, fantasy.ReasoningContent{Text: part.Text})
889 } else {
890 content = append(content, fantasy.TextContent{Text: part.Text})
891 }
892 case part.FunctionCall != nil:
893 input, err := json.Marshal(part.FunctionCall.Args)
894 if err != nil {
895 return nil, err
896 }
897 toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
898 content = append(content, fantasy.ToolCallContent{
899 ToolCallID: toolCallID,
900 ToolName: part.FunctionCall.Name,
901 Input: string(input),
902 ProviderExecuted: false,
903 })
904 hasToolCalls = true
905 default:
906 // Silently skip unknown part types instead of erroring
907 // This allows for forward compatibility with new part types
908 }
909 }
910
911 if hasToolCalls {
912 finishReason = fantasy.FinishReasonToolCalls
913 } else {
914 finishReason = mapFinishReason(candidate.FinishReason)
915 }
916
917 return &fantasy.Response{
918 Content: content,
919 Usage: mapUsage(response.UsageMetadata),
920 FinishReason: finishReason,
921 Warnings: warnings,
922 }, nil
923}
924
925func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason {
926 switch reason {
927 case genai.FinishReasonStop:
928 return fantasy.FinishReasonStop
929 case genai.FinishReasonMaxTokens:
930 return fantasy.FinishReasonLength
931 case genai.FinishReasonSafety,
932 genai.FinishReasonBlocklist,
933 genai.FinishReasonProhibitedContent,
934 genai.FinishReasonSPII,
935 genai.FinishReasonImageSafety:
936 return fantasy.FinishReasonContentFilter
937 case genai.FinishReasonRecitation,
938 genai.FinishReasonLanguage,
939 genai.FinishReasonMalformedFunctionCall:
940 return fantasy.FinishReasonError
941 case genai.FinishReasonOther:
942 return fantasy.FinishReasonOther
943 default:
944 return fantasy.FinishReasonUnknown
945 }
946}
947
948func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage {
949 return fantasy.Usage{
950 InputTokens: int64(usage.ToolUsePromptTokenCount),
951 OutputTokens: int64(usage.CandidatesTokenCount),
952 TotalTokens: int64(usage.TotalTokenCount),
953 ReasoningTokens: int64(usage.ThoughtsTokenCount),
954 CacheCreationTokens: int64(usage.CachedContentTokenCount),
955 CacheReadTokens: 0,
956 }
957}