1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "net/http"
11 "reflect"
12 "strings"
13
14 "charm.land/fantasy"
15 "charm.land/fantasy/object"
16 "charm.land/fantasy/providers/anthropic"
17 "charm.land/fantasy/schema"
18 "cloud.google.com/go/auth"
19 "github.com/charmbracelet/x/exp/slice"
20 "github.com/google/uuid"
21 "google.golang.org/genai"
22)
23
// Name is the default name of the Google provider. It is also the key this
// package uses for its entries in provider option and provider metadata maps.
const Name = "google"
26
// provider is the value returned by New. It holds the resolved options that
// LanguageModel uses to build clients.
type provider struct {
	options options
}
30
// ToolCallIDFunc defines a function that generates a tool call ID. It
// supplies the ID for streamed function calls that arrive without one.
type ToolCallIDFunc = func() string
33
// options collects everything configurable through the With* functions.
type options struct {
	// apiKey is forwarded to the genai client configuration.
	apiKey string
	// name is the provider name reported by language models; New defaults it to Name.
	name string
	// baseURL, when non-empty, overrides the genai client's base URL.
	baseURL string
	// headers are extra HTTP headers added to the genai client's HTTP options.
	headers map[string]string
	// client is the HTTP client handed to the genai client (and to the
	// anthropic provider for Claude models).
	client *http.Client
	// backend selects between the Gemini API and Vertex AI.
	backend genai.Backend
	// project and location are the Vertex AI project and location.
	project  string
	location string
	// skipAuth installs a dummy token provider instead of real credentials.
	skipAuth bool
	// toolCallIDFunc generates tool call IDs; New defaults it to a UUID generator.
	toolCallIDFunc ToolCallIDFunc
	// objectMode selects the object generation strategy; empty means auto.
	objectMode fantasy.ObjectMode
}
47
// Option defines a function that configures Google provider options. Options
// are applied by New in the order they are passed.
type Option = func(*options)
50
51// New creates a new Google provider with the given options.
52func New(opts ...Option) (fantasy.Provider, error) {
53 options := options{
54 headers: map[string]string{},
55 toolCallIDFunc: func() string {
56 return uuid.NewString()
57 },
58 }
59 for _, o := range opts {
60 o(&options)
61 }
62
63 options.name = cmp.Or(options.name, Name)
64
65 return &provider{
66 options: options,
67 }, nil
68}
69
70// WithBaseURL sets the base URL for the Google provider.
71func WithBaseURL(baseURL string) Option {
72 return func(o *options) {
73 o.baseURL = baseURL
74 }
75}
76
77// WithGeminiAPIKey sets the Gemini API key for the Google provider.
78func WithGeminiAPIKey(apiKey string) Option {
79 return func(o *options) {
80 o.backend = genai.BackendGeminiAPI
81 o.apiKey = apiKey
82 o.project = ""
83 o.location = ""
84 }
85}
86
87// WithVertex configures the Google provider to use Vertex AI.
88func WithVertex(project, location string) Option {
89 if project == "" || location == "" {
90 panic("project and location must be provided")
91 }
92 return func(o *options) {
93 o.backend = genai.BackendVertexAI
94 o.apiKey = ""
95 o.project = project
96 o.location = location
97 }
98}
99
100// WithSkipAuth configures whether to skip authentication for the Google provider.
101func WithSkipAuth(skipAuth bool) Option {
102 return func(o *options) {
103 o.skipAuth = skipAuth
104 }
105}
106
107// WithName sets the name for the Google provider.
108func WithName(name string) Option {
109 return func(o *options) {
110 o.name = name
111 }
112}
113
114// WithHeaders sets the headers for the Google provider.
115func WithHeaders(headers map[string]string) Option {
116 return func(o *options) {
117 maps.Copy(o.headers, headers)
118 }
119}
120
121// WithHTTPClient sets the HTTP client for the Google provider.
122func WithHTTPClient(client *http.Client) Option {
123 return func(o *options) {
124 o.client = client
125 }
126}
127
128// WithToolCallIDFunc sets the function that generates a tool call ID.
129func WithToolCallIDFunc(f ToolCallIDFunc) Option {
130 return func(o *options) {
131 o.toolCallIDFunc = f
132 }
133}
134
135// WithObjectMode sets the object generation mode for the Google provider.
136func WithObjectMode(om fantasy.ObjectMode) Option {
137 return func(o *options) {
138 o.objectMode = om
139 }
140}
141
// Name returns the package-level Name constant. It does not reflect a custom
// name configured through WithName.
func (*provider) Name() string {
	return Name
}
145
// languageModel is the Google-backed model handle created by
// provider.LanguageModel.
type languageModel struct {
	// provider is the provider name returned by Provider.
	provider string
	// modelID is the model identifier passed to the genai client.
	modelID string
	// client is the genai client used for all requests.
	client *genai.Client
	// providerOptions is a copy of the provider's options taken at creation.
	providerOptions options
	// objectMode selects the strategy used by GenerateObject and StreamObject.
	objectMode fantasy.ObjectMode
}
153
154// LanguageModel implements fantasy.Provider.
155func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.LanguageModel, error) {
156 if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
157 p, err := anthropic.New(
158 anthropic.WithVertex(a.options.project, a.options.location),
159 anthropic.WithHTTPClient(a.options.client),
160 anthropic.WithSkipAuth(a.options.skipAuth),
161 )
162 if err != nil {
163 return nil, err
164 }
165 return p.LanguageModel(ctx, modelID)
166 }
167
168 cc := &genai.ClientConfig{
169 HTTPClient: a.options.client,
170 Backend: a.options.backend,
171 APIKey: a.options.apiKey,
172 Project: a.options.project,
173 Location: a.options.location,
174 }
175 if a.options.skipAuth {
176 cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
177 } else if cc.Backend == genai.BackendVertexAI {
178 if err := cc.UseDefaultCredentials(); err != nil {
179 return nil, err
180 }
181 }
182
183 if a.options.baseURL != "" || len(a.options.headers) > 0 {
184 headers := http.Header{}
185 for k, v := range a.options.headers {
186 headers.Add(k, v)
187 }
188 cc.HTTPOptions = genai.HTTPOptions{
189 BaseURL: a.options.baseURL,
190 Headers: headers,
191 }
192 }
193 client, err := genai.NewClient(ctx, cc)
194 if err != nil {
195 return nil, err
196 }
197
198 objectMode := a.options.objectMode
199 if objectMode == "" {
200 objectMode = fantasy.ObjectModeAuto
201 }
202
203 return &languageModel{
204 modelID: modelID,
205 provider: a.options.name,
206 providerOptions: a.options,
207 client: client,
208 objectMode: objectMode,
209 }, nil
210}
211
212func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) {
213 config := &genai.GenerateContentConfig{}
214
215 providerOptions := &ProviderOptions{}
216 if v, ok := call.ProviderOptions[Name]; ok {
217 providerOptions, ok = v.(*ProviderOptions)
218 if !ok {
219 return nil, nil, nil, &fantasy.Error{Title: "invalid argument", Message: "google provider options should be *google.ProviderOptions"}
220 }
221 }
222
223 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
224
225 if providerOptions.ThinkingConfig != nil {
226 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
227 *providerOptions.ThinkingConfig.IncludeThoughts &&
228 strings.HasPrefix(g.provider, "google.vertex.") {
229 warnings = append(warnings, fantasy.CallWarning{
230 Type: fantasy.CallWarningTypeOther,
231 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
232 "and might not be supported or could behave unexpectedly with the current Google provider " +
233 fmt.Sprintf("(%s)", g.provider),
234 })
235 }
236
237 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
238 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
239 warnings = append(warnings, fantasy.CallWarning{
240 Type: fantasy.CallWarningTypeOther,
241 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
242 })
243 providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128))
244 }
245
246 if providerOptions.ThinkingConfig.ThinkingLevel != nil &&
247 providerOptions.ThinkingConfig.ThinkingBudget != nil {
248 return nil, nil, nil, &fantasy.Error{
249 Title: "invalid argument",
250 Message: "thinking_level and thinking_budget are mutually exclusive",
251 }
252 }
253 }
254
255 isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")
256
257 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
258 if len(content) > 0 && content[0].Role == genai.RoleUser {
259 systemParts := []string{}
260 for _, sp := range systemInstructions.Parts {
261 systemParts = append(systemParts, sp.Text)
262 }
263 systemMsg := strings.Join(systemParts, "\n")
264 content[0].Parts = append([]*genai.Part{
265 {
266 Text: systemMsg + "\n\n",
267 },
268 }, content[0].Parts...)
269 systemInstructions = nil
270 }
271 }
272
273 config.SystemInstruction = systemInstructions
274
275 if call.MaxOutputTokens != nil {
276 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
277 }
278
279 if call.Temperature != nil {
280 tmp := float32(*call.Temperature)
281 config.Temperature = &tmp
282 }
283 if call.TopK != nil {
284 tmp := float32(*call.TopK)
285 config.TopK = &tmp
286 }
287 if call.TopP != nil {
288 tmp := float32(*call.TopP)
289 config.TopP = &tmp
290 }
291 if call.FrequencyPenalty != nil {
292 tmp := float32(*call.FrequencyPenalty)
293 config.FrequencyPenalty = &tmp
294 }
295 if call.PresencePenalty != nil {
296 tmp := float32(*call.PresencePenalty)
297 config.PresencePenalty = &tmp
298 }
299
300 if providerOptions.ThinkingConfig != nil {
301 config.ThinkingConfig = &genai.ThinkingConfig{}
302 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
303 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
304 }
305 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
306 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
307 config.ThinkingConfig.ThinkingBudget = &tmp
308 }
309 if providerOptions.ThinkingConfig.ThinkingLevel != nil {
310 config.ThinkingConfig.ThinkingLevel = genai.ThinkingLevel(*providerOptions.ThinkingConfig.ThinkingLevel)
311 }
312 }
313 for _, safetySetting := range providerOptions.SafetySettings {
314 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
315 Category: genai.HarmCategory(safetySetting.Category),
316 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
317 })
318 }
319 if providerOptions.CachedContent != "" {
320 config.CachedContent = providerOptions.CachedContent
321 }
322
323 if len(call.Tools) > 0 {
324 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
325 config.ToolConfig = toolChoice
326 config.Tools = append(config.Tools, &genai.Tool{
327 FunctionDeclarations: tools,
328 })
329 warnings = append(warnings, toolWarnings...)
330 }
331
332 return config, content, warnings, nil
333}
334
335func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam
336 var systemInstructions *genai.Content
337 var content []*genai.Content
338 var warnings []fantasy.CallWarning
339
340 finishedSystemBlock := false
341 for _, msg := range prompt {
342 switch msg.Role {
343 case fantasy.MessageRoleSystem:
344 if finishedSystemBlock {
345 // skip multiple system messages that are separated by user/assistant messages
346 // TODO: see if we need to send error here?
347 continue
348 }
349 finishedSystemBlock = true
350
351 var systemMessages []string
352 for _, part := range msg.Content {
353 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
354 if !ok || text.Text == "" {
355 continue
356 }
357 systemMessages = append(systemMessages, text.Text)
358 }
359 if len(systemMessages) > 0 {
360 systemInstructions = &genai.Content{
361 Parts: []*genai.Part{
362 {
363 Text: strings.Join(systemMessages, "\n"),
364 },
365 },
366 }
367 }
368 case fantasy.MessageRoleUser:
369 var parts []*genai.Part
370 for _, part := range msg.Content {
371 switch part.GetType() {
372 case fantasy.ContentTypeText:
373 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
374 if !ok || text.Text == "" {
375 continue
376 }
377 parts = append(parts, &genai.Part{
378 Text: text.Text,
379 })
380 case fantasy.ContentTypeFile:
381 file, ok := fantasy.AsMessagePart[fantasy.FilePart](part)
382 if !ok {
383 continue
384 }
385 parts = append(parts, &genai.Part{
386 InlineData: &genai.Blob{
387 Data: file.Data,
388 MIMEType: file.MediaType,
389 },
390 })
391 }
392 }
393 if len(parts) > 0 {
394 content = append(content, &genai.Content{
395 Role: genai.RoleUser,
396 Parts: parts,
397 })
398 }
399 case fantasy.MessageRoleAssistant:
400 var parts []*genai.Part
401 var currentReasoningMetadata *ReasoningMetadata
402 for _, part := range msg.Content {
403 switch part.GetType() {
404 case fantasy.ContentTypeReasoning:
405 reasoning, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](part)
406 if !ok {
407 continue
408 }
409
410 metadata, ok := reasoning.ProviderOptions[Name]
411 if !ok {
412 continue
413 }
414 reasoningMetadata, ok := metadata.(*ReasoningMetadata)
415 if !ok {
416 continue
417 }
418 currentReasoningMetadata = reasoningMetadata
419 case fantasy.ContentTypeText:
420 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
421 if !ok || text.Text == "" {
422 continue
423 }
424 geminiPart := &genai.Part{
425 Text: text.Text,
426 }
427 if currentReasoningMetadata != nil {
428 geminiPart.ThoughtSignature = []byte(currentReasoningMetadata.Signature)
429 currentReasoningMetadata = nil
430 }
431 parts = append(parts, geminiPart)
432 case fantasy.ContentTypeToolCall:
433 toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
434 if !ok {
435 continue
436 }
437
438 var result map[string]any
439 err := json.Unmarshal([]byte(toolCall.Input), &result)
440 if err != nil {
441 continue
442 }
443 geminiPart := &genai.Part{
444 FunctionCall: &genai.FunctionCall{
445 ID: toolCall.ToolCallID,
446 Name: toolCall.ToolName,
447 Args: result,
448 },
449 }
450 if currentReasoningMetadata != nil {
451 geminiPart.ThoughtSignature = []byte(currentReasoningMetadata.Signature)
452 currentReasoningMetadata = nil
453 }
454 parts = append(parts, geminiPart)
455 }
456 }
457 if len(parts) > 0 {
458 content = append(content, &genai.Content{
459 Role: genai.RoleModel,
460 Parts: parts,
461 })
462 }
463 case fantasy.MessageRoleTool:
464 var parts []*genai.Part
465 for _, part := range msg.Content {
466 switch part.GetType() {
467 case fantasy.ContentTypeToolResult:
468 result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
469 if !ok {
470 continue
471 }
472 var toolCall fantasy.ToolCallPart
473 for _, m := range prompt {
474 if m.Role == fantasy.MessageRoleAssistant {
475 for _, content := range m.Content {
476 tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content)
477 if !ok {
478 continue
479 }
480 if tc.ToolCallID == result.ToolCallID {
481 toolCall = tc
482 break
483 }
484 }
485 }
486 }
487 switch result.Output.GetType() {
488 case fantasy.ToolResultContentTypeText:
489 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output)
490 if !ok {
491 continue
492 }
493 response := map[string]any{"result": content.Text}
494 parts = append(parts, &genai.Part{
495 FunctionResponse: &genai.FunctionResponse{
496 ID: result.ToolCallID,
497 Response: response,
498 Name: toolCall.ToolName,
499 },
500 })
501
502 case fantasy.ToolResultContentTypeError:
503 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
504 if !ok {
505 continue
506 }
507 response := map[string]any{"result": content.Error.Error()}
508 parts = append(parts, &genai.Part{
509 FunctionResponse: &genai.FunctionResponse{
510 ID: result.ToolCallID,
511 Response: response,
512 Name: toolCall.ToolName,
513 },
514 })
515 }
516 }
517 }
518 if len(parts) > 0 {
519 content = append(content, &genai.Content{
520 Role: genai.RoleUser,
521 Parts: parts,
522 })
523 }
524 default:
525 panic("unsupported message role: " + msg.Role)
526 }
527 }
528 return systemInstructions, content, warnings
529}
530
531// Generate implements fantasy.LanguageModel.
532func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
533 config, contents, warnings, err := g.prepareParams(call)
534 if err != nil {
535 return nil, err
536 }
537
538 lastMessage, history, ok := slice.Pop(contents)
539 if !ok {
540 return nil, errors.New("no messages to send")
541 }
542
543 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
544 if err != nil {
545 return nil, err
546 }
547
548 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
549 if err != nil {
550 return nil, toProviderErr(err)
551 }
552
553 return g.mapResponse(response, warnings)
554}
555
// Model implements fantasy.LanguageModel. It returns the model ID this
// language model was created with.
func (g *languageModel) Model() string {
	return g.modelID
}
560
// Provider implements fantasy.LanguageModel. It returns the provider name
// stored on the model when it was created.
func (g *languageModel) Provider() string {
	return g.provider
}
565
566// Stream implements fantasy.LanguageModel.
567func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
568 config, contents, warnings, err := g.prepareParams(call)
569 if err != nil {
570 return nil, err
571 }
572
573 lastMessage, history, ok := slice.Pop(contents)
574 if !ok {
575 return nil, errors.New("no messages to send")
576 }
577
578 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
579 if err != nil {
580 return nil, err
581 }
582
583 return func(yield func(fantasy.StreamPart) bool) {
584 if len(warnings) > 0 {
585 if !yield(fantasy.StreamPart{
586 Type: fantasy.StreamPartTypeWarnings,
587 Warnings: warnings,
588 }) {
589 return
590 }
591 }
592
593 var currentContent string
594 var toolCalls []fantasy.ToolCallContent
595 var isActiveText bool
596 var isActiveReasoning bool
597 var blockCounter int
598 var currentTextBlockID string
599 var currentReasoningBlockID string
600 var usage *fantasy.Usage
601 var lastFinishReason fantasy.FinishReason
602
603 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
604 if err != nil {
605 yield(fantasy.StreamPart{
606 Type: fantasy.StreamPartTypeError,
607 Error: toProviderErr(err),
608 })
609 return
610 }
611
612 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
613 for _, part := range resp.Candidates[0].Content.Parts {
614 switch {
615 case part.Text != "":
616 delta := part.Text
617 if delta != "" {
618 // Check if this is a reasoning/thought part
619 if part.Thought {
620 // End any active text block before starting reasoning
621 if isActiveText {
622 isActiveText = false
623 if !yield(fantasy.StreamPart{
624 Type: fantasy.StreamPartTypeTextEnd,
625 ID: currentTextBlockID,
626 }) {
627 return
628 }
629 }
630
631 // Start new reasoning block if not already active
632 if !isActiveReasoning {
633 isActiveReasoning = true
634 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
635 blockCounter++
636 if !yield(fantasy.StreamPart{
637 Type: fantasy.StreamPartTypeReasoningStart,
638 ID: currentReasoningBlockID,
639 }) {
640 return
641 }
642 }
643
644 if !yield(fantasy.StreamPart{
645 Type: fantasy.StreamPartTypeReasoningDelta,
646 ID: currentReasoningBlockID,
647 Delta: delta,
648 }) {
649 return
650 }
651 } else {
652 // Start new text block if not already active
653 if !isActiveText {
654 isActiveText = true
655 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
656 blockCounter++
657 if !yield(fantasy.StreamPart{
658 Type: fantasy.StreamPartTypeTextStart,
659 ID: currentTextBlockID,
660 }) {
661 return
662 }
663 }
664 // End any active reasoning block before starting text
665 if isActiveReasoning {
666 isActiveReasoning = false
667 metadata := &ReasoningMetadata{
668 Signature: string(part.ThoughtSignature),
669 }
670 if !yield(fantasy.StreamPart{
671 Type: fantasy.StreamPartTypeReasoningEnd,
672 ID: currentReasoningBlockID,
673 ProviderMetadata: fantasy.ProviderMetadata{
674 Name: metadata,
675 },
676 }) {
677 return
678 }
679 } else if part.ThoughtSignature != nil {
680 metadata := &ReasoningMetadata{
681 Signature: string(part.ThoughtSignature),
682 }
683
684 if !yield(fantasy.StreamPart{
685 Type: fantasy.StreamPartTypeReasoningStart,
686 ID: currentReasoningBlockID,
687 }) {
688 return
689 }
690 if !yield(fantasy.StreamPart{
691 Type: fantasy.StreamPartTypeReasoningEnd,
692 ID: currentReasoningBlockID,
693 ProviderMetadata: fantasy.ProviderMetadata{
694 Name: metadata,
695 },
696 }) {
697 return
698 }
699 }
700
701 if !yield(fantasy.StreamPart{
702 Type: fantasy.StreamPartTypeTextDelta,
703 ID: currentTextBlockID,
704 Delta: delta,
705 }) {
706 return
707 }
708 currentContent += delta
709 }
710 }
711 case part.FunctionCall != nil:
712 // End any active text or reasoning blocks
713 if isActiveText {
714 isActiveText = false
715 if !yield(fantasy.StreamPart{
716 Type: fantasy.StreamPartTypeTextEnd,
717 ID: currentTextBlockID,
718 }) {
719 return
720 }
721 }
722 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
723 // End any active reasoning block before starting text
724 if isActiveReasoning {
725 isActiveReasoning = false
726 metadata := &ReasoningMetadata{
727 Signature: string(part.ThoughtSignature),
728 ToolID: toolCallID,
729 }
730 if !yield(fantasy.StreamPart{
731 Type: fantasy.StreamPartTypeReasoningEnd,
732 ID: currentReasoningBlockID,
733 ProviderMetadata: fantasy.ProviderMetadata{
734 Name: metadata,
735 },
736 }) {
737 return
738 }
739 } else if part.ThoughtSignature != nil {
740 metadata := &ReasoningMetadata{
741 Signature: string(part.ThoughtSignature),
742 ToolID: toolCallID,
743 }
744
745 if !yield(fantasy.StreamPart{
746 Type: fantasy.StreamPartTypeReasoningStart,
747 ID: currentReasoningBlockID,
748 }) {
749 return
750 }
751 if !yield(fantasy.StreamPart{
752 Type: fantasy.StreamPartTypeReasoningEnd,
753 ID: currentReasoningBlockID,
754 ProviderMetadata: fantasy.ProviderMetadata{
755 Name: metadata,
756 },
757 }) {
758 return
759 }
760 }
761 args, err := json.Marshal(part.FunctionCall.Args)
762 if err != nil {
763 yield(fantasy.StreamPart{
764 Type: fantasy.StreamPartTypeError,
765 Error: err,
766 })
767 return
768 }
769
770 if !yield(fantasy.StreamPart{
771 Type: fantasy.StreamPartTypeToolInputStart,
772 ID: toolCallID,
773 ToolCallName: part.FunctionCall.Name,
774 }) {
775 return
776 }
777
778 if !yield(fantasy.StreamPart{
779 Type: fantasy.StreamPartTypeToolInputDelta,
780 ID: toolCallID,
781 Delta: string(args),
782 }) {
783 return
784 }
785
786 if !yield(fantasy.StreamPart{
787 Type: fantasy.StreamPartTypeToolInputEnd,
788 ID: toolCallID,
789 }) {
790 return
791 }
792
793 if !yield(fantasy.StreamPart{
794 Type: fantasy.StreamPartTypeToolCall,
795 ID: toolCallID,
796 ToolCallName: part.FunctionCall.Name,
797 ToolCallInput: string(args),
798 ProviderExecuted: false,
799 }) {
800 return
801 }
802
803 toolCalls = append(toolCalls, fantasy.ToolCallContent{
804 ToolCallID: toolCallID,
805 ToolName: part.FunctionCall.Name,
806 Input: string(args),
807 ProviderExecuted: false,
808 })
809 }
810 }
811 }
812
813 // we need to make sure that there is actual tokendata
814 if resp.UsageMetadata != nil && resp.UsageMetadata.TotalTokenCount != 0 {
815 currentUsage := mapUsage(resp.UsageMetadata)
816 // if first usage chunk
817 if usage == nil {
818 usage = ¤tUsage
819 } else {
820 usage.OutputTokens += currentUsage.OutputTokens
821 usage.ReasoningTokens += currentUsage.ReasoningTokens
822 usage.CacheReadTokens += currentUsage.CacheReadTokens
823 }
824 }
825
826 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
827 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
828 }
829 }
830
831 // Close any open blocks before finishing
832 if isActiveText {
833 if !yield(fantasy.StreamPart{
834 Type: fantasy.StreamPartTypeTextEnd,
835 ID: currentTextBlockID,
836 }) {
837 return
838 }
839 }
840 if isActiveReasoning {
841 if !yield(fantasy.StreamPart{
842 Type: fantasy.StreamPartTypeReasoningEnd,
843 ID: currentReasoningBlockID,
844 }) {
845 return
846 }
847 }
848
849 finishReason := lastFinishReason
850 if len(toolCalls) > 0 {
851 finishReason = fantasy.FinishReasonToolCalls
852 } else if finishReason == "" {
853 finishReason = fantasy.FinishReasonStop
854 }
855
856 var finalUsage fantasy.Usage
857 if usage != nil {
858 finalUsage = *usage
859 }
860
861 yield(fantasy.StreamPart{
862 Type: fantasy.StreamPartTypeFinish,
863 Usage: finalUsage,
864 FinishReason: finishReason,
865 })
866 }, nil
867}
868
869// GenerateObject implements fantasy.LanguageModel.
870func (g *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
871 switch g.objectMode {
872 case fantasy.ObjectModeText:
873 return object.GenerateWithText(ctx, g, call)
874 case fantasy.ObjectModeTool:
875 return object.GenerateWithTool(ctx, g, call)
876 default:
877 return g.generateObjectWithJSONMode(ctx, call)
878 }
879}
880
881// StreamObject implements fantasy.LanguageModel.
882func (g *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
883 switch g.objectMode {
884 case fantasy.ObjectModeTool:
885 return object.StreamWithTool(ctx, g, call)
886 case fantasy.ObjectModeText:
887 return object.StreamWithText(ctx, g, call)
888 default:
889 return g.streamObjectWithJSONMode(ctx, call)
890 }
891}
892
893func (g *languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
894 // Convert our Schema to Google's JSON Schema format
895 jsonSchemaMap := schema.ToMap(call.Schema)
896
897 // Build request using prepareParams
898 fantasyCall := fantasy.Call{
899 Prompt: call.Prompt,
900 MaxOutputTokens: call.MaxOutputTokens,
901 Temperature: call.Temperature,
902 TopP: call.TopP,
903 TopK: call.TopK,
904 PresencePenalty: call.PresencePenalty,
905 FrequencyPenalty: call.FrequencyPenalty,
906 ProviderOptions: call.ProviderOptions,
907 }
908
909 config, contents, warnings, err := g.prepareParams(fantasyCall)
910 if err != nil {
911 return nil, err
912 }
913
914 // Set ResponseMIMEType and ResponseJsonSchema for structured output
915 config.ResponseMIMEType = "application/json"
916 config.ResponseJsonSchema = jsonSchemaMap
917
918 lastMessage, history, ok := slice.Pop(contents)
919 if !ok {
920 return nil, errors.New("no messages to send")
921 }
922
923 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
924 if err != nil {
925 return nil, err
926 }
927
928 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
929 if err != nil {
930 return nil, toProviderErr(err)
931 }
932
933 mappedResponse, err := g.mapResponse(response, warnings)
934 if err != nil {
935 return nil, err
936 }
937
938 jsonText := mappedResponse.Content.Text()
939 if jsonText == "" {
940 return nil, &fantasy.NoObjectGeneratedError{
941 RawText: "",
942 ParseError: fmt.Errorf("no text content in response"),
943 Usage: mappedResponse.Usage,
944 FinishReason: mappedResponse.FinishReason,
945 }
946 }
947
948 // Parse and validate
949 var obj any
950 if call.RepairText != nil {
951 obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
952 } else {
953 obj, err = schema.ParseAndValidate(jsonText, call.Schema)
954 }
955
956 if err != nil {
957 // Add usage info to error
958 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
959 nogErr.Usage = mappedResponse.Usage
960 nogErr.FinishReason = mappedResponse.FinishReason
961 }
962 return nil, err
963 }
964
965 return &fantasy.ObjectResponse{
966 Object: obj,
967 RawText: jsonText,
968 Usage: mappedResponse.Usage,
969 FinishReason: mappedResponse.FinishReason,
970 Warnings: warnings,
971 ProviderMetadata: mappedResponse.ProviderMetadata,
972 }, nil
973}
974
975func (g *languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
976 // Convert our Schema to Google's JSON Schema format
977 jsonSchemaMap := schema.ToMap(call.Schema)
978
979 // Build request using prepareParams
980 fantasyCall := fantasy.Call{
981 Prompt: call.Prompt,
982 MaxOutputTokens: call.MaxOutputTokens,
983 Temperature: call.Temperature,
984 TopP: call.TopP,
985 TopK: call.TopK,
986 PresencePenalty: call.PresencePenalty,
987 FrequencyPenalty: call.FrequencyPenalty,
988 ProviderOptions: call.ProviderOptions,
989 }
990
991 config, contents, warnings, err := g.prepareParams(fantasyCall)
992 if err != nil {
993 return nil, err
994 }
995
996 // Set ResponseMIMEType and ResponseJsonSchema for structured output
997 config.ResponseMIMEType = "application/json"
998 config.ResponseJsonSchema = jsonSchemaMap
999
1000 lastMessage, history, ok := slice.Pop(contents)
1001 if !ok {
1002 return nil, errors.New("no messages to send")
1003 }
1004
1005 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
1006 if err != nil {
1007 return nil, err
1008 }
1009
1010 return func(yield func(fantasy.ObjectStreamPart) bool) {
1011 if len(warnings) > 0 {
1012 if !yield(fantasy.ObjectStreamPart{
1013 Type: fantasy.ObjectStreamPartTypeObject,
1014 Warnings: warnings,
1015 }) {
1016 return
1017 }
1018 }
1019
1020 var accumulated string
1021 var lastParsedObject any
1022 var usage *fantasy.Usage
1023 var lastFinishReason fantasy.FinishReason
1024 var streamErr error
1025
1026 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
1027 if err != nil {
1028 streamErr = toProviderErr(err)
1029 yield(fantasy.ObjectStreamPart{
1030 Type: fantasy.ObjectStreamPartTypeError,
1031 Error: streamErr,
1032 })
1033 return
1034 }
1035
1036 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
1037 for _, part := range resp.Candidates[0].Content.Parts {
1038 if part.Text != "" && !part.Thought {
1039 accumulated += part.Text
1040
1041 // Try to parse the accumulated text
1042 obj, state, parseErr := schema.ParsePartialJSON(accumulated)
1043
1044 // If we successfully parsed, validate and emit
1045 if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
1046 if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
1047 // Only emit if object is different from last
1048 if !reflect.DeepEqual(obj, lastParsedObject) {
1049 if !yield(fantasy.ObjectStreamPart{
1050 Type: fantasy.ObjectStreamPartTypeObject,
1051 Object: obj,
1052 }) {
1053 return
1054 }
1055 lastParsedObject = obj
1056 }
1057 }
1058 }
1059
1060 // If parsing failed and we have a repair function, try it
1061 if state == schema.ParseStateFailed && call.RepairText != nil {
1062 repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
1063 if repairErr == nil {
1064 obj2, state2, _ := schema.ParsePartialJSON(repairedText)
1065 if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
1066 schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
1067 if !reflect.DeepEqual(obj2, lastParsedObject) {
1068 if !yield(fantasy.ObjectStreamPart{
1069 Type: fantasy.ObjectStreamPartTypeObject,
1070 Object: obj2,
1071 }) {
1072 return
1073 }
1074 lastParsedObject = obj2
1075 }
1076 }
1077 }
1078 }
1079 }
1080 }
1081 }
1082
1083 // we need to make sure that there is actual tokendata
1084 if resp.UsageMetadata != nil && resp.UsageMetadata.TotalTokenCount != 0 {
1085 currentUsage := mapUsage(resp.UsageMetadata)
1086 if usage == nil {
1087 usage = ¤tUsage
1088 } else {
1089 usage.OutputTokens += currentUsage.OutputTokens
1090 usage.ReasoningTokens += currentUsage.ReasoningTokens
1091 usage.CacheReadTokens += currentUsage.CacheReadTokens
1092 }
1093 }
1094
1095 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
1096 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
1097 }
1098 }
1099
1100 // Final validation and emit
1101 if streamErr == nil && lastParsedObject != nil {
1102 finishReason := cmp.Or(lastFinishReason, fantasy.FinishReasonStop)
1103
1104 var finalUsage fantasy.Usage
1105 if usage != nil {
1106 finalUsage = *usage
1107 }
1108
1109 yield(fantasy.ObjectStreamPart{
1110 Type: fantasy.ObjectStreamPartTypeFinish,
1111 Usage: finalUsage,
1112 FinishReason: finishReason,
1113 })
1114 } else if streamErr == nil && lastParsedObject == nil {
1115 // No object was generated
1116 var finalUsage fantasy.Usage
1117 if usage != nil {
1118 finalUsage = *usage
1119 }
1120 yield(fantasy.ObjectStreamPart{
1121 Type: fantasy.ObjectStreamPartTypeError,
1122 Error: &fantasy.NoObjectGeneratedError{
1123 RawText: accumulated,
1124 ParseError: fmt.Errorf("no valid object generated in stream"),
1125 Usage: finalUsage,
1126 FinishReason: lastFinishReason,
1127 },
1128 })
1129 }
1130 }, nil
1131}
1132
1133func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) {
1134 for _, tool := range tools {
1135 if tool.GetType() == fantasy.ToolTypeFunction {
1136 ft, ok := tool.(fantasy.FunctionTool)
1137 if !ok {
1138 continue
1139 }
1140
1141 var required []string
1142 var properties map[string]any
1143 if props, ok := ft.InputSchema["properties"]; ok {
1144 properties, _ = props.(map[string]any)
1145 }
1146 if req, ok := ft.InputSchema["required"]; ok {
1147 if reqArr, ok := req.([]string); ok {
1148 required = reqArr
1149 }
1150 }
1151 declaration := &genai.FunctionDeclaration{
1152 Name: ft.Name,
1153 Description: ft.Description,
1154 Parameters: &genai.Schema{
1155 Type: genai.TypeObject,
1156 Properties: convertSchemaProperties(properties),
1157 Required: required,
1158 },
1159 }
1160 googleTools = append(googleTools, declaration)
1161 continue
1162 }
1163 // TODO: handle provider tool calls
1164 warnings = append(warnings, fantasy.CallWarning{
1165 Type: fantasy.CallWarningTypeUnsupportedTool,
1166 Tool: tool,
1167 Message: "tool is not supported",
1168 })
1169 }
1170 if toolChoice == nil {
1171 return googleTools, googleToolChoice, warnings
1172 }
1173 switch *toolChoice {
1174 case fantasy.ToolChoiceAuto:
1175 googleToolChoice = &genai.ToolConfig{
1176 FunctionCallingConfig: &genai.FunctionCallingConfig{
1177 Mode: genai.FunctionCallingConfigModeAuto,
1178 },
1179 }
1180 case fantasy.ToolChoiceRequired:
1181 googleToolChoice = &genai.ToolConfig{
1182 FunctionCallingConfig: &genai.FunctionCallingConfig{
1183 Mode: genai.FunctionCallingConfigModeAny,
1184 },
1185 }
1186 case fantasy.ToolChoiceNone:
1187 googleToolChoice = &genai.ToolConfig{
1188 FunctionCallingConfig: &genai.FunctionCallingConfig{
1189 Mode: genai.FunctionCallingConfigModeNone,
1190 },
1191 }
1192 default:
1193 googleToolChoice = &genai.ToolConfig{
1194 FunctionCallingConfig: &genai.FunctionCallingConfig{
1195 Mode: genai.FunctionCallingConfigModeAny,
1196 AllowedFunctionNames: []string{
1197 string(*toolChoice),
1198 },
1199 },
1200 }
1201 }
1202 return googleTools, googleToolChoice, warnings
1203}
1204
1205func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
1206 properties := make(map[string]*genai.Schema)
1207
1208 for name, param := range parameters {
1209 properties[name] = convertToSchema(param)
1210 }
1211
1212 return properties
1213}
1214
1215func convertToSchema(param any) *genai.Schema {
1216 schema := &genai.Schema{Type: genai.TypeString}
1217
1218 paramMap, ok := param.(map[string]any)
1219 if !ok {
1220 return schema
1221 }
1222
1223 if desc, ok := paramMap["description"].(string); ok {
1224 schema.Description = desc
1225 }
1226
1227 typeVal, hasType := paramMap["type"]
1228 if !hasType {
1229 return schema
1230 }
1231
1232 typeStr, ok := typeVal.(string)
1233 if !ok {
1234 return schema
1235 }
1236
1237 schema.Type = mapJSONTypeToGoogle(typeStr)
1238
1239 switch typeStr {
1240 case "array":
1241 schema.Items = processArrayItems(paramMap)
1242 case "object":
1243 if props, ok := paramMap["properties"].(map[string]any); ok {
1244 schema.Properties = convertSchemaProperties(props)
1245 }
1246 }
1247
1248 return schema
1249}
1250
1251func processArrayItems(paramMap map[string]any) *genai.Schema {
1252 items, ok := paramMap["items"].(map[string]any)
1253 if !ok {
1254 return nil
1255 }
1256
1257 return convertToSchema(items)
1258}
1259
1260func mapJSONTypeToGoogle(jsonType string) genai.Type {
1261 switch jsonType {
1262 case "string":
1263 return genai.TypeString
1264 case "number":
1265 return genai.TypeNumber
1266 case "integer":
1267 return genai.TypeInteger
1268 case "boolean":
1269 return genai.TypeBoolean
1270 case "array":
1271 return genai.TypeArray
1272 case "object":
1273 return genai.TypeObject
1274 default:
1275 return genai.TypeString // Default to string for unknown types
1276 }
1277}
1278
1279func (g languageModel) mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) {
1280 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
1281 return nil, errors.New("no response from model")
1282 }
1283
1284 var (
1285 content []fantasy.Content
1286 finishReason fantasy.FinishReason
1287 hasToolCalls bool
1288 candidate = response.Candidates[0]
1289 )
1290
1291 for _, part := range candidate.Content.Parts {
1292 switch {
1293 case part.Text != "":
1294 if part.Thought {
1295 reasoningContent := fantasy.ReasoningContent{Text: part.Text}
1296 if part.ThoughtSignature != nil {
1297 metadata := &ReasoningMetadata{
1298 Signature: string(part.ThoughtSignature),
1299 }
1300 reasoningContent.ProviderMetadata = fantasy.ProviderMetadata{
1301 Name: metadata,
1302 }
1303 }
1304 content = append(content, reasoningContent)
1305 } else {
1306 foundReasoning := false
1307 if part.ThoughtSignature != nil {
1308 metadata := &ReasoningMetadata{
1309 Signature: string(part.ThoughtSignature),
1310 }
1311 // find the last reasoning content and add the signature
1312 for i := len(content) - 1; i >= 0; i-- {
1313 c := content[i]
1314 if c.GetType() == fantasy.ContentTypeReasoning {
1315 reasoningContent, ok := fantasy.AsContentType[fantasy.ReasoningContent](c)
1316 if !ok {
1317 continue
1318 }
1319 reasoningContent.ProviderMetadata = fantasy.ProviderMetadata{
1320 Name: metadata,
1321 }
1322 content[i] = reasoningContent
1323 foundReasoning = true
1324 break
1325 }
1326 }
1327 if !foundReasoning {
1328 content = append(content, fantasy.ReasoningContent{
1329 ProviderMetadata: fantasy.ProviderMetadata{
1330 Name: metadata,
1331 },
1332 })
1333 }
1334 }
1335 content = append(content, fantasy.TextContent{Text: part.Text})
1336 }
1337 case part.FunctionCall != nil:
1338 input, err := json.Marshal(part.FunctionCall.Args)
1339 if err != nil {
1340 return nil, err
1341 }
1342 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
1343 foundReasoning := false
1344 if part.ThoughtSignature != nil {
1345 metadata := &ReasoningMetadata{
1346 Signature: string(part.ThoughtSignature),
1347 ToolID: toolCallID,
1348 }
1349 // find the last reasoning content and add the signature
1350 for i := len(content) - 1; i >= 0; i-- {
1351 c := content[i]
1352 if c.GetType() == fantasy.ContentTypeReasoning {
1353 reasoningContent, ok := fantasy.AsContentType[fantasy.ReasoningContent](c)
1354 if !ok {
1355 continue
1356 }
1357 reasoningContent.ProviderMetadata = fantasy.ProviderMetadata{
1358 Name: metadata,
1359 }
1360 content[i] = reasoningContent
1361 foundReasoning = true
1362 break
1363 }
1364 }
1365 if !foundReasoning {
1366 content = append(content, fantasy.ReasoningContent{
1367 ProviderMetadata: fantasy.ProviderMetadata{
1368 Name: metadata,
1369 },
1370 })
1371 }
1372 }
1373 content = append(content, fantasy.ToolCallContent{
1374 ToolCallID: toolCallID,
1375 ToolName: part.FunctionCall.Name,
1376 Input: string(input),
1377 ProviderExecuted: false,
1378 })
1379 hasToolCalls = true
1380 default:
1381 // Silently skip unknown part types instead of erroring
1382 // This allows for forward compatibility with new part types
1383 }
1384 }
1385
1386 if hasToolCalls {
1387 finishReason = fantasy.FinishReasonToolCalls
1388 } else {
1389 finishReason = mapFinishReason(candidate.FinishReason)
1390 }
1391
1392 return &fantasy.Response{
1393 Content: content,
1394 Usage: mapUsage(response.UsageMetadata),
1395 FinishReason: finishReason,
1396 Warnings: warnings,
1397 }, nil
1398}
1399
1400// GetReasoningMetadata extracts reasoning metadata from provider options for google models.
1401func GetReasoningMetadata(providerOptions fantasy.ProviderOptions) *ReasoningMetadata {
1402 if googleOptions, ok := providerOptions[Name]; ok {
1403 if reasoning, ok := googleOptions.(*ReasoningMetadata); ok {
1404 return reasoning
1405 }
1406 }
1407 return nil
1408}
1409
1410func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason {
1411 switch reason {
1412 case genai.FinishReasonStop:
1413 return fantasy.FinishReasonStop
1414 case genai.FinishReasonMaxTokens:
1415 return fantasy.FinishReasonLength
1416 case genai.FinishReasonSafety,
1417 genai.FinishReasonBlocklist,
1418 genai.FinishReasonProhibitedContent,
1419 genai.FinishReasonSPII,
1420 genai.FinishReasonImageSafety:
1421 return fantasy.FinishReasonContentFilter
1422 case genai.FinishReasonRecitation,
1423 genai.FinishReasonLanguage,
1424 genai.FinishReasonMalformedFunctionCall:
1425 return fantasy.FinishReasonError
1426 case genai.FinishReasonOther:
1427 return fantasy.FinishReasonOther
1428 default:
1429 return fantasy.FinishReasonUnknown
1430 }
1431}
1432
1433func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage {
1434 return fantasy.Usage{
1435 InputTokens: int64(usage.PromptTokenCount),
1436 OutputTokens: int64(usage.CandidatesTokenCount),
1437 TotalTokens: int64(usage.TotalTokenCount),
1438 ReasoningTokens: int64(usage.ThoughtsTokenCount),
1439 CacheCreationTokens: 0,
1440 CacheReadTokens: int64(usage.CachedContentTokenCount),
1441 }
1442}