1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "net/http"
11 "reflect"
12 "strings"
13
14 "charm.land/fantasy"
15 "charm.land/fantasy/object"
16 "charm.land/fantasy/providers/anthropic"
17 "charm.land/fantasy/schema"
18 "cloud.google.com/go/auth"
19 "github.com/charmbracelet/x/exp/slice"
20 "github.com/google/uuid"
21 "google.golang.org/genai"
22)
23
// Name is the name of the Google provider.
const Name = "google"

// provider is the fantasy.Provider implementation for Google models. The
// backend (Gemini API or Vertex AI) is chosen by options.backend.
type provider struct {
	options options
}

// ToolCallIDFunc defines a function that generates a tool call ID.
type ToolCallIDFunc = func() string

// options holds the configuration assembled by the Option functions passed
// to New.
type options struct {
	// apiKey is the Gemini API key (cleared by WithVertex).
	apiKey string
	// name is the provider name reported by language models; New defaults it
	// to Name.
	name string
	// baseURL, when non-empty, overrides the genai client's base URL.
	baseURL string
	// headers are extra HTTP headers passed to the genai client.
	headers map[string]string
	// client is an optional custom HTTP client.
	client *http.Client
	// backend selects the Gemini API or Vertex AI.
	backend genai.Backend
	// project and location identify the Vertex AI deployment (cleared by
	// WithGeminiAPIKey).
	project  string
	location string
	// skipAuth replaces real credentials with a dummy token provider.
	skipAuth bool
	// toolCallIDFunc generates IDs for function calls returned without one.
	toolCallIDFunc ToolCallIDFunc
	// objectMode selects how GenerateObject/StreamObject build objects; an
	// empty value is treated as fantasy.ObjectModeAuto by LanguageModel.
	objectMode fantasy.ObjectMode
}

// Option defines a function that configures Google provider options.
type Option = func(*options)
50
51// New creates a new Google provider with the given options.
52func New(opts ...Option) (fantasy.Provider, error) {
53 options := options{
54 headers: map[string]string{},
55 toolCallIDFunc: func() string {
56 return uuid.NewString()
57 },
58 }
59 for _, o := range opts {
60 o(&options)
61 }
62
63 options.name = cmp.Or(options.name, Name)
64
65 return &provider{
66 options: options,
67 }, nil
68}
69
70// WithBaseURL sets the base URL for the Google provider.
71func WithBaseURL(baseURL string) Option {
72 return func(o *options) {
73 o.baseURL = baseURL
74 }
75}
76
77// WithGeminiAPIKey sets the Gemini API key for the Google provider.
78func WithGeminiAPIKey(apiKey string) Option {
79 return func(o *options) {
80 o.backend = genai.BackendGeminiAPI
81 o.apiKey = apiKey
82 o.project = ""
83 o.location = ""
84 }
85}
86
87// WithVertex configures the Google provider to use Vertex AI.
88func WithVertex(project, location string) Option {
89 if project == "" || location == "" {
90 panic("project and location must be provided")
91 }
92 return func(o *options) {
93 o.backend = genai.BackendVertexAI
94 o.apiKey = ""
95 o.project = project
96 o.location = location
97 }
98}
99
100// WithSkipAuth configures whether to skip authentication for the Google provider.
101func WithSkipAuth(skipAuth bool) Option {
102 return func(o *options) {
103 o.skipAuth = skipAuth
104 }
105}
106
107// WithName sets the name for the Google provider.
108func WithName(name string) Option {
109 return func(o *options) {
110 o.name = name
111 }
112}
113
114// WithHeaders sets the headers for the Google provider.
115func WithHeaders(headers map[string]string) Option {
116 return func(o *options) {
117 maps.Copy(o.headers, headers)
118 }
119}
120
121// WithHTTPClient sets the HTTP client for the Google provider.
122func WithHTTPClient(client *http.Client) Option {
123 return func(o *options) {
124 o.client = client
125 }
126}
127
128// WithToolCallIDFunc sets the function that generates a tool call ID.
129func WithToolCallIDFunc(f ToolCallIDFunc) Option {
130 return func(o *options) {
131 o.toolCallIDFunc = f
132 }
133}
134
135// WithObjectMode sets the object generation mode for the Google provider.
136func WithObjectMode(om fantasy.ObjectMode) Option {
137 return func(o *options) {
138 o.objectMode = om
139 }
140}
141
142func (*provider) Name() string {
143 return Name
144}
145
// languageModel is the fantasy.LanguageModel implementation backed by a genai
// client. Instances are created by provider.LanguageModel.
type languageModel struct {
	// provider is the name returned by Provider (options.name).
	provider string
	// modelID is the model every genai request targets.
	modelID string
	client  *genai.Client
	// providerOptions is the provider's options, copied by value when the
	// model was created (maps such as headers are still shared).
	providerOptions options
	// objectMode selects the GenerateObject/StreamObject strategy.
	objectMode fantasy.ObjectMode
}
153
154// LanguageModel implements fantasy.Provider.
155func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.LanguageModel, error) {
156 if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
157 p, err := anthropic.New(
158 anthropic.WithVertex(a.options.project, a.options.location),
159 anthropic.WithHTTPClient(a.options.client),
160 anthropic.WithSkipAuth(a.options.skipAuth),
161 )
162 if err != nil {
163 return nil, err
164 }
165 return p.LanguageModel(ctx, modelID)
166 }
167
168 cc := &genai.ClientConfig{
169 HTTPClient: a.options.client,
170 Backend: a.options.backend,
171 APIKey: a.options.apiKey,
172 Project: a.options.project,
173 Location: a.options.location,
174 }
175 if a.options.skipAuth {
176 cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
177 } else if cc.Backend == genai.BackendVertexAI {
178 if err := cc.UseDefaultCredentials(); err != nil {
179 return nil, err
180 }
181 }
182
183 if a.options.baseURL != "" || len(a.options.headers) > 0 {
184 headers := http.Header{}
185 for k, v := range a.options.headers {
186 headers.Add(k, v)
187 }
188 cc.HTTPOptions = genai.HTTPOptions{
189 BaseURL: a.options.baseURL,
190 Headers: headers,
191 }
192 }
193 client, err := genai.NewClient(ctx, cc)
194 if err != nil {
195 return nil, err
196 }
197
198 objectMode := a.options.objectMode
199 if objectMode == "" {
200 objectMode = fantasy.ObjectModeAuto
201 }
202
203 return &languageModel{
204 modelID: modelID,
205 provider: a.options.name,
206 providerOptions: a.options,
207 client: client,
208 objectMode: objectMode,
209 }, nil
210}
211
212func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) {
213 config := &genai.GenerateContentConfig{}
214
215 providerOptions := &ProviderOptions{}
216 if v, ok := call.ProviderOptions[Name]; ok {
217 providerOptions, ok = v.(*ProviderOptions)
218 if !ok {
219 return nil, nil, nil, &fantasy.Error{Title: "invalid argument", Message: "google provider options should be *google.ProviderOptions"}
220 }
221 }
222
223 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
224
225 if providerOptions.ThinkingConfig != nil {
226 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
227 *providerOptions.ThinkingConfig.IncludeThoughts &&
228 strings.HasPrefix(g.provider, "google.vertex.") {
229 warnings = append(warnings, fantasy.CallWarning{
230 Type: fantasy.CallWarningTypeOther,
231 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
232 "and might not be supported or could behave unexpectedly with the current Google provider " +
233 fmt.Sprintf("(%s)", g.provider),
234 })
235 }
236
237 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
238 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
239 warnings = append(warnings, fantasy.CallWarning{
240 Type: fantasy.CallWarningTypeOther,
241 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
242 })
243 providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128))
244 }
245 }
246
247 isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")
248
249 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
250 if len(content) > 0 && content[0].Role == genai.RoleUser {
251 systemParts := []string{}
252 for _, sp := range systemInstructions.Parts {
253 systemParts = append(systemParts, sp.Text)
254 }
255 systemMsg := strings.Join(systemParts, "\n")
256 content[0].Parts = append([]*genai.Part{
257 {
258 Text: systemMsg + "\n\n",
259 },
260 }, content[0].Parts...)
261 systemInstructions = nil
262 }
263 }
264
265 config.SystemInstruction = systemInstructions
266
267 if call.MaxOutputTokens != nil {
268 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
269 }
270
271 if call.Temperature != nil {
272 tmp := float32(*call.Temperature)
273 config.Temperature = &tmp
274 }
275 if call.TopK != nil {
276 tmp := float32(*call.TopK)
277 config.TopK = &tmp
278 }
279 if call.TopP != nil {
280 tmp := float32(*call.TopP)
281 config.TopP = &tmp
282 }
283 if call.FrequencyPenalty != nil {
284 tmp := float32(*call.FrequencyPenalty)
285 config.FrequencyPenalty = &tmp
286 }
287 if call.PresencePenalty != nil {
288 tmp := float32(*call.PresencePenalty)
289 config.PresencePenalty = &tmp
290 }
291
292 if providerOptions.ThinkingConfig != nil {
293 config.ThinkingConfig = &genai.ThinkingConfig{}
294 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
295 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
296 }
297 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
298 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
299 config.ThinkingConfig.ThinkingBudget = &tmp
300 }
301 }
302 for _, safetySetting := range providerOptions.SafetySettings {
303 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
304 Category: genai.HarmCategory(safetySetting.Category),
305 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
306 })
307 }
308 if providerOptions.CachedContent != "" {
309 config.CachedContent = providerOptions.CachedContent
310 }
311
312 if len(call.Tools) > 0 {
313 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
314 config.ToolConfig = toolChoice
315 config.Tools = append(config.Tools, &genai.Tool{
316 FunctionDeclarations: tools,
317 })
318 warnings = append(warnings, toolWarnings...)
319 }
320
321 return config, content, warnings, nil
322}
323
324func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam
325 var systemInstructions *genai.Content
326 var content []*genai.Content
327 var warnings []fantasy.CallWarning
328
329 finishedSystemBlock := false
330 for _, msg := range prompt {
331 switch msg.Role {
332 case fantasy.MessageRoleSystem:
333 if finishedSystemBlock {
334 // skip multiple system messages that are separated by user/assistant messages
335 // TODO: see if we need to send error here?
336 continue
337 }
338 finishedSystemBlock = true
339
340 var systemMessages []string
341 for _, part := range msg.Content {
342 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
343 if !ok || text.Text == "" {
344 continue
345 }
346 systemMessages = append(systemMessages, text.Text)
347 }
348 if len(systemMessages) > 0 {
349 systemInstructions = &genai.Content{
350 Parts: []*genai.Part{
351 {
352 Text: strings.Join(systemMessages, "\n"),
353 },
354 },
355 }
356 }
357 case fantasy.MessageRoleUser:
358 var parts []*genai.Part
359 for _, part := range msg.Content {
360 switch part.GetType() {
361 case fantasy.ContentTypeText:
362 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
363 if !ok || text.Text == "" {
364 continue
365 }
366 parts = append(parts, &genai.Part{
367 Text: text.Text,
368 })
369 case fantasy.ContentTypeFile:
370 file, ok := fantasy.AsMessagePart[fantasy.FilePart](part)
371 if !ok {
372 continue
373 }
374 parts = append(parts, &genai.Part{
375 InlineData: &genai.Blob{
376 Data: file.Data,
377 MIMEType: file.MediaType,
378 },
379 })
380 }
381 }
382 if len(parts) > 0 {
383 content = append(content, &genai.Content{
384 Role: genai.RoleUser,
385 Parts: parts,
386 })
387 }
388 case fantasy.MessageRoleAssistant:
389 var parts []*genai.Part
390 var currentReasoningMetadata *ReasoningMetadata
391 for _, part := range msg.Content {
392 switch part.GetType() {
393 case fantasy.ContentTypeReasoning:
394 reasoning, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](part)
395 if !ok {
396 continue
397 }
398
399 metadata, ok := reasoning.ProviderOptions[Name]
400 if !ok {
401 continue
402 }
403 reasoningMetadata, ok := metadata.(*ReasoningMetadata)
404 if !ok {
405 continue
406 }
407 currentReasoningMetadata = reasoningMetadata
408 case fantasy.ContentTypeText:
409 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
410 if !ok || text.Text == "" {
411 continue
412 }
413 geminiPart := &genai.Part{
414 Text: text.Text,
415 }
416 if currentReasoningMetadata != nil {
417 geminiPart.ThoughtSignature = []byte(currentReasoningMetadata.Signature)
418 currentReasoningMetadata = nil
419 }
420 parts = append(parts, geminiPart)
421 case fantasy.ContentTypeToolCall:
422 toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
423 if !ok {
424 continue
425 }
426
427 var result map[string]any
428 err := json.Unmarshal([]byte(toolCall.Input), &result)
429 if err != nil {
430 continue
431 }
432 geminiPart := &genai.Part{
433 FunctionCall: &genai.FunctionCall{
434 ID: toolCall.ToolCallID,
435 Name: toolCall.ToolName,
436 Args: result,
437 },
438 }
439 if currentReasoningMetadata != nil {
440 geminiPart.ThoughtSignature = []byte(currentReasoningMetadata.Signature)
441 currentReasoningMetadata = nil
442 }
443 parts = append(parts, geminiPart)
444 }
445 }
446 if len(parts) > 0 {
447 content = append(content, &genai.Content{
448 Role: genai.RoleModel,
449 Parts: parts,
450 })
451 }
452 case fantasy.MessageRoleTool:
453 var parts []*genai.Part
454 for _, part := range msg.Content {
455 switch part.GetType() {
456 case fantasy.ContentTypeToolResult:
457 result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
458 if !ok {
459 continue
460 }
461 var toolCall fantasy.ToolCallPart
462 for _, m := range prompt {
463 if m.Role == fantasy.MessageRoleAssistant {
464 for _, content := range m.Content {
465 tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content)
466 if !ok {
467 continue
468 }
469 if tc.ToolCallID == result.ToolCallID {
470 toolCall = tc
471 break
472 }
473 }
474 }
475 }
476 switch result.Output.GetType() {
477 case fantasy.ToolResultContentTypeText:
478 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output)
479 if !ok {
480 continue
481 }
482 response := map[string]any{"result": content.Text}
483 parts = append(parts, &genai.Part{
484 FunctionResponse: &genai.FunctionResponse{
485 ID: result.ToolCallID,
486 Response: response,
487 Name: toolCall.ToolName,
488 },
489 })
490
491 case fantasy.ToolResultContentTypeError:
492 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
493 if !ok {
494 continue
495 }
496 response := map[string]any{"result": content.Error.Error()}
497 parts = append(parts, &genai.Part{
498 FunctionResponse: &genai.FunctionResponse{
499 ID: result.ToolCallID,
500 Response: response,
501 Name: toolCall.ToolName,
502 },
503 })
504 }
505 }
506 }
507 if len(parts) > 0 {
508 content = append(content, &genai.Content{
509 Role: genai.RoleUser,
510 Parts: parts,
511 })
512 }
513 default:
514 panic("unsupported message role: " + msg.Role)
515 }
516 }
517 return systemInstructions, content, warnings
518}
519
520// Generate implements fantasy.LanguageModel.
521func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
522 config, contents, warnings, err := g.prepareParams(call)
523 if err != nil {
524 return nil, err
525 }
526
527 lastMessage, history, ok := slice.Pop(contents)
528 if !ok {
529 return nil, errors.New("no messages to send")
530 }
531
532 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
533 if err != nil {
534 return nil, err
535 }
536
537 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
538 if err != nil {
539 return nil, toProviderErr(err)
540 }
541
542 return g.mapResponse(response, warnings)
543}
544
545// Model implements fantasy.LanguageModel.
546func (g *languageModel) Model() string {
547 return g.modelID
548}
549
550// Provider implements fantasy.LanguageModel.
551func (g *languageModel) Provider() string {
552 return g.provider
553}
554
555// Stream implements fantasy.LanguageModel.
556func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
557 config, contents, warnings, err := g.prepareParams(call)
558 if err != nil {
559 return nil, err
560 }
561
562 lastMessage, history, ok := slice.Pop(contents)
563 if !ok {
564 return nil, errors.New("no messages to send")
565 }
566
567 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
568 if err != nil {
569 return nil, err
570 }
571
572 return func(yield func(fantasy.StreamPart) bool) {
573 if len(warnings) > 0 {
574 if !yield(fantasy.StreamPart{
575 Type: fantasy.StreamPartTypeWarnings,
576 Warnings: warnings,
577 }) {
578 return
579 }
580 }
581
582 var currentContent string
583 var toolCalls []fantasy.ToolCallContent
584 var isActiveText bool
585 var isActiveReasoning bool
586 var blockCounter int
587 var currentTextBlockID string
588 var currentReasoningBlockID string
589 var usage *fantasy.Usage
590 var lastFinishReason fantasy.FinishReason
591
592 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
593 if err != nil {
594 yield(fantasy.StreamPart{
595 Type: fantasy.StreamPartTypeError,
596 Error: toProviderErr(err),
597 })
598 return
599 }
600
601 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
602 for _, part := range resp.Candidates[0].Content.Parts {
603 switch {
604 case part.Text != "":
605 delta := part.Text
606 if delta != "" {
607 // Check if this is a reasoning/thought part
608 if part.Thought {
609 // End any active text block before starting reasoning
610 if isActiveText {
611 isActiveText = false
612 if !yield(fantasy.StreamPart{
613 Type: fantasy.StreamPartTypeTextEnd,
614 ID: currentTextBlockID,
615 }) {
616 return
617 }
618 }
619
620 // Start new reasoning block if not already active
621 if !isActiveReasoning {
622 isActiveReasoning = true
623 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
624 blockCounter++
625 if !yield(fantasy.StreamPart{
626 Type: fantasy.StreamPartTypeReasoningStart,
627 ID: currentReasoningBlockID,
628 }) {
629 return
630 }
631 }
632
633 if !yield(fantasy.StreamPart{
634 Type: fantasy.StreamPartTypeReasoningDelta,
635 ID: currentReasoningBlockID,
636 Delta: delta,
637 }) {
638 return
639 }
640 } else {
641 // Start new text block if not already active
642 if !isActiveText {
643 isActiveText = true
644 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
645 blockCounter++
646 if !yield(fantasy.StreamPart{
647 Type: fantasy.StreamPartTypeTextStart,
648 ID: currentTextBlockID,
649 }) {
650 return
651 }
652 }
653 // End any active reasoning block before starting text
654 if isActiveReasoning {
655 isActiveReasoning = false
656 metadata := &ReasoningMetadata{
657 Signature: string(part.ThoughtSignature),
658 }
659 if !yield(fantasy.StreamPart{
660 Type: fantasy.StreamPartTypeReasoningEnd,
661 ID: currentReasoningBlockID,
662 ProviderMetadata: fantasy.ProviderMetadata{
663 Name: metadata,
664 },
665 }) {
666 return
667 }
668 } else if part.ThoughtSignature != nil {
669 metadata := &ReasoningMetadata{
670 Signature: string(part.ThoughtSignature),
671 }
672
673 if !yield(fantasy.StreamPart{
674 Type: fantasy.StreamPartTypeReasoningStart,
675 ID: currentReasoningBlockID,
676 }) {
677 return
678 }
679 if !yield(fantasy.StreamPart{
680 Type: fantasy.StreamPartTypeReasoningEnd,
681 ID: currentReasoningBlockID,
682 ProviderMetadata: fantasy.ProviderMetadata{
683 Name: metadata,
684 },
685 }) {
686 return
687 }
688 }
689
690 if !yield(fantasy.StreamPart{
691 Type: fantasy.StreamPartTypeTextDelta,
692 ID: currentTextBlockID,
693 Delta: delta,
694 }) {
695 return
696 }
697 currentContent += delta
698 }
699 }
700 case part.FunctionCall != nil:
701 // End any active text or reasoning blocks
702 if isActiveText {
703 isActiveText = false
704 if !yield(fantasy.StreamPart{
705 Type: fantasy.StreamPartTypeTextEnd,
706 ID: currentTextBlockID,
707 }) {
708 return
709 }
710 }
711 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
712 // End any active reasoning block before starting text
713 if isActiveReasoning {
714 isActiveReasoning = false
715 metadata := &ReasoningMetadata{
716 Signature: string(part.ThoughtSignature),
717 ToolID: toolCallID,
718 }
719 if !yield(fantasy.StreamPart{
720 Type: fantasy.StreamPartTypeReasoningEnd,
721 ID: currentReasoningBlockID,
722 ProviderMetadata: fantasy.ProviderMetadata{
723 Name: metadata,
724 },
725 }) {
726 return
727 }
728 } else if part.ThoughtSignature != nil {
729 metadata := &ReasoningMetadata{
730 Signature: string(part.ThoughtSignature),
731 ToolID: toolCallID,
732 }
733
734 if !yield(fantasy.StreamPart{
735 Type: fantasy.StreamPartTypeReasoningStart,
736 ID: currentReasoningBlockID,
737 }) {
738 return
739 }
740 if !yield(fantasy.StreamPart{
741 Type: fantasy.StreamPartTypeReasoningEnd,
742 ID: currentReasoningBlockID,
743 ProviderMetadata: fantasy.ProviderMetadata{
744 Name: metadata,
745 },
746 }) {
747 return
748 }
749 }
750 args, err := json.Marshal(part.FunctionCall.Args)
751 if err != nil {
752 yield(fantasy.StreamPart{
753 Type: fantasy.StreamPartTypeError,
754 Error: err,
755 })
756 return
757 }
758
759 if !yield(fantasy.StreamPart{
760 Type: fantasy.StreamPartTypeToolInputStart,
761 ID: toolCallID,
762 ToolCallName: part.FunctionCall.Name,
763 }) {
764 return
765 }
766
767 if !yield(fantasy.StreamPart{
768 Type: fantasy.StreamPartTypeToolInputDelta,
769 ID: toolCallID,
770 Delta: string(args),
771 }) {
772 return
773 }
774
775 if !yield(fantasy.StreamPart{
776 Type: fantasy.StreamPartTypeToolInputEnd,
777 ID: toolCallID,
778 }) {
779 return
780 }
781
782 if !yield(fantasy.StreamPart{
783 Type: fantasy.StreamPartTypeToolCall,
784 ID: toolCallID,
785 ToolCallName: part.FunctionCall.Name,
786 ToolCallInput: string(args),
787 ProviderExecuted: false,
788 }) {
789 return
790 }
791
792 toolCalls = append(toolCalls, fantasy.ToolCallContent{
793 ToolCallID: toolCallID,
794 ToolName: part.FunctionCall.Name,
795 Input: string(args),
796 ProviderExecuted: false,
797 })
798 }
799 }
800 }
801
802 // we need to make sure that there is actual tokendata
803 if resp.UsageMetadata != nil && resp.UsageMetadata.TotalTokenCount != 0 {
804 currentUsage := mapUsage(resp.UsageMetadata)
805 // if first usage chunk
806 if usage == nil {
807 usage = ¤tUsage
808 } else {
809 usage.OutputTokens += currentUsage.OutputTokens
810 usage.ReasoningTokens += currentUsage.ReasoningTokens
811 usage.CacheReadTokens += currentUsage.CacheReadTokens
812 }
813 }
814
815 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
816 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
817 }
818 }
819
820 // Close any open blocks before finishing
821 if isActiveText {
822 if !yield(fantasy.StreamPart{
823 Type: fantasy.StreamPartTypeTextEnd,
824 ID: currentTextBlockID,
825 }) {
826 return
827 }
828 }
829 if isActiveReasoning {
830 if !yield(fantasy.StreamPart{
831 Type: fantasy.StreamPartTypeReasoningEnd,
832 ID: currentReasoningBlockID,
833 }) {
834 return
835 }
836 }
837
838 finishReason := lastFinishReason
839 if len(toolCalls) > 0 {
840 finishReason = fantasy.FinishReasonToolCalls
841 } else if finishReason == "" {
842 finishReason = fantasy.FinishReasonStop
843 }
844
845 yield(fantasy.StreamPart{
846 Type: fantasy.StreamPartTypeFinish,
847 Usage: *usage,
848 FinishReason: finishReason,
849 })
850 }, nil
851}
852
853// GenerateObject implements fantasy.LanguageModel.
854func (g *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
855 switch g.objectMode {
856 case fantasy.ObjectModeText:
857 return object.GenerateWithText(ctx, g, call)
858 case fantasy.ObjectModeTool:
859 return object.GenerateWithTool(ctx, g, call)
860 default:
861 return g.generateObjectWithJSONMode(ctx, call)
862 }
863}
864
865// StreamObject implements fantasy.LanguageModel.
866func (g *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
867 switch g.objectMode {
868 case fantasy.ObjectModeTool:
869 return object.StreamWithTool(ctx, g, call)
870 case fantasy.ObjectModeText:
871 return object.StreamWithText(ctx, g, call)
872 default:
873 return g.streamObjectWithJSONMode(ctx, call)
874 }
875}
876
877func (g *languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
878 // Convert our Schema to Google's JSON Schema format
879 jsonSchemaMap := schema.ToMap(call.Schema)
880
881 // Build request using prepareParams
882 fantasyCall := fantasy.Call{
883 Prompt: call.Prompt,
884 MaxOutputTokens: call.MaxOutputTokens,
885 Temperature: call.Temperature,
886 TopP: call.TopP,
887 TopK: call.TopK,
888 PresencePenalty: call.PresencePenalty,
889 FrequencyPenalty: call.FrequencyPenalty,
890 ProviderOptions: call.ProviderOptions,
891 }
892
893 config, contents, warnings, err := g.prepareParams(fantasyCall)
894 if err != nil {
895 return nil, err
896 }
897
898 // Set ResponseMIMEType and ResponseJsonSchema for structured output
899 config.ResponseMIMEType = "application/json"
900 config.ResponseJsonSchema = jsonSchemaMap
901
902 lastMessage, history, ok := slice.Pop(contents)
903 if !ok {
904 return nil, errors.New("no messages to send")
905 }
906
907 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
908 if err != nil {
909 return nil, err
910 }
911
912 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
913 if err != nil {
914 return nil, toProviderErr(err)
915 }
916
917 mappedResponse, err := g.mapResponse(response, warnings)
918 if err != nil {
919 return nil, err
920 }
921
922 jsonText := mappedResponse.Content.Text()
923 if jsonText == "" {
924 return nil, &fantasy.NoObjectGeneratedError{
925 RawText: "",
926 ParseError: fmt.Errorf("no text content in response"),
927 Usage: mappedResponse.Usage,
928 FinishReason: mappedResponse.FinishReason,
929 }
930 }
931
932 // Parse and validate
933 var obj any
934 if call.RepairText != nil {
935 obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
936 } else {
937 obj, err = schema.ParseAndValidate(jsonText, call.Schema)
938 }
939
940 if err != nil {
941 // Add usage info to error
942 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
943 nogErr.Usage = mappedResponse.Usage
944 nogErr.FinishReason = mappedResponse.FinishReason
945 }
946 return nil, err
947 }
948
949 return &fantasy.ObjectResponse{
950 Object: obj,
951 RawText: jsonText,
952 Usage: mappedResponse.Usage,
953 FinishReason: mappedResponse.FinishReason,
954 Warnings: warnings,
955 ProviderMetadata: mappedResponse.ProviderMetadata,
956 }, nil
957}
958
959func (g *languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
960 // Convert our Schema to Google's JSON Schema format
961 jsonSchemaMap := schema.ToMap(call.Schema)
962
963 // Build request using prepareParams
964 fantasyCall := fantasy.Call{
965 Prompt: call.Prompt,
966 MaxOutputTokens: call.MaxOutputTokens,
967 Temperature: call.Temperature,
968 TopP: call.TopP,
969 TopK: call.TopK,
970 PresencePenalty: call.PresencePenalty,
971 FrequencyPenalty: call.FrequencyPenalty,
972 ProviderOptions: call.ProviderOptions,
973 }
974
975 config, contents, warnings, err := g.prepareParams(fantasyCall)
976 if err != nil {
977 return nil, err
978 }
979
980 // Set ResponseMIMEType and ResponseJsonSchema for structured output
981 config.ResponseMIMEType = "application/json"
982 config.ResponseJsonSchema = jsonSchemaMap
983
984 lastMessage, history, ok := slice.Pop(contents)
985 if !ok {
986 return nil, errors.New("no messages to send")
987 }
988
989 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
990 if err != nil {
991 return nil, err
992 }
993
994 return func(yield func(fantasy.ObjectStreamPart) bool) {
995 if len(warnings) > 0 {
996 if !yield(fantasy.ObjectStreamPart{
997 Type: fantasy.ObjectStreamPartTypeObject,
998 Warnings: warnings,
999 }) {
1000 return
1001 }
1002 }
1003
1004 var accumulated string
1005 var lastParsedObject any
1006 var usage *fantasy.Usage
1007 var lastFinishReason fantasy.FinishReason
1008 var streamErr error
1009
1010 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
1011 if err != nil {
1012 streamErr = toProviderErr(err)
1013 yield(fantasy.ObjectStreamPart{
1014 Type: fantasy.ObjectStreamPartTypeError,
1015 Error: streamErr,
1016 })
1017 return
1018 }
1019
1020 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
1021 for _, part := range resp.Candidates[0].Content.Parts {
1022 if part.Text != "" && !part.Thought {
1023 accumulated += part.Text
1024
1025 // Try to parse the accumulated text
1026 obj, state, parseErr := schema.ParsePartialJSON(accumulated)
1027
1028 // If we successfully parsed, validate and emit
1029 if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
1030 if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
1031 // Only emit if object is different from last
1032 if !reflect.DeepEqual(obj, lastParsedObject) {
1033 if !yield(fantasy.ObjectStreamPart{
1034 Type: fantasy.ObjectStreamPartTypeObject,
1035 Object: obj,
1036 }) {
1037 return
1038 }
1039 lastParsedObject = obj
1040 }
1041 }
1042 }
1043
1044 // If parsing failed and we have a repair function, try it
1045 if state == schema.ParseStateFailed && call.RepairText != nil {
1046 repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
1047 if repairErr == nil {
1048 obj2, state2, _ := schema.ParsePartialJSON(repairedText)
1049 if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
1050 schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
1051 if !reflect.DeepEqual(obj2, lastParsedObject) {
1052 if !yield(fantasy.ObjectStreamPart{
1053 Type: fantasy.ObjectStreamPartTypeObject,
1054 Object: obj2,
1055 }) {
1056 return
1057 }
1058 lastParsedObject = obj2
1059 }
1060 }
1061 }
1062 }
1063 }
1064 }
1065 }
1066
1067 // we need to make sure that there is actual tokendata
1068 if resp.UsageMetadata != nil && resp.UsageMetadata.TotalTokenCount != 0 {
1069 currentUsage := mapUsage(resp.UsageMetadata)
1070 if usage == nil {
1071 usage = ¤tUsage
1072 } else {
1073 usage.OutputTokens += currentUsage.OutputTokens
1074 usage.ReasoningTokens += currentUsage.ReasoningTokens
1075 usage.CacheReadTokens += currentUsage.CacheReadTokens
1076 }
1077 }
1078
1079 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
1080 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
1081 }
1082 }
1083
1084 // Final validation and emit
1085 if streamErr == nil && lastParsedObject != nil {
1086 finishReason := lastFinishReason
1087 if finishReason == "" {
1088 finishReason = fantasy.FinishReasonStop
1089 }
1090
1091 yield(fantasy.ObjectStreamPart{
1092 Type: fantasy.ObjectStreamPartTypeFinish,
1093 Usage: *usage,
1094 FinishReason: finishReason,
1095 })
1096 } else if streamErr == nil && lastParsedObject == nil {
1097 // No object was generated
1098 finalUsage := fantasy.Usage{}
1099 if usage != nil {
1100 finalUsage = *usage
1101 }
1102 yield(fantasy.ObjectStreamPart{
1103 Type: fantasy.ObjectStreamPartTypeError,
1104 Error: &fantasy.NoObjectGeneratedError{
1105 RawText: accumulated,
1106 ParseError: fmt.Errorf("no valid object generated in stream"),
1107 Usage: finalUsage,
1108 FinishReason: lastFinishReason,
1109 },
1110 })
1111 }
1112 }, nil
1113}
1114
1115func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) {
1116 for _, tool := range tools {
1117 if tool.GetType() == fantasy.ToolTypeFunction {
1118 ft, ok := tool.(fantasy.FunctionTool)
1119 if !ok {
1120 continue
1121 }
1122
1123 required := []string{}
1124 var properties map[string]any
1125 if props, ok := ft.InputSchema["properties"]; ok {
1126 properties, _ = props.(map[string]any)
1127 }
1128 if req, ok := ft.InputSchema["required"]; ok {
1129 if reqArr, ok := req.([]string); ok {
1130 required = reqArr
1131 }
1132 }
1133 declaration := &genai.FunctionDeclaration{
1134 Name: ft.Name,
1135 Description: ft.Description,
1136 Parameters: &genai.Schema{
1137 Type: genai.TypeObject,
1138 Properties: convertSchemaProperties(properties),
1139 Required: required,
1140 },
1141 }
1142 googleTools = append(googleTools, declaration)
1143 continue
1144 }
1145 // TODO: handle provider tool calls
1146 warnings = append(warnings, fantasy.CallWarning{
1147 Type: fantasy.CallWarningTypeUnsupportedTool,
1148 Tool: tool,
1149 Message: "tool is not supported",
1150 })
1151 }
1152 if toolChoice == nil {
1153 return googleTools, googleToolChoice, warnings
1154 }
1155 switch *toolChoice {
1156 case fantasy.ToolChoiceAuto:
1157 googleToolChoice = &genai.ToolConfig{
1158 FunctionCallingConfig: &genai.FunctionCallingConfig{
1159 Mode: genai.FunctionCallingConfigModeAuto,
1160 },
1161 }
1162 case fantasy.ToolChoiceRequired:
1163 googleToolChoice = &genai.ToolConfig{
1164 FunctionCallingConfig: &genai.FunctionCallingConfig{
1165 Mode: genai.FunctionCallingConfigModeAny,
1166 },
1167 }
1168 case fantasy.ToolChoiceNone:
1169 googleToolChoice = &genai.ToolConfig{
1170 FunctionCallingConfig: &genai.FunctionCallingConfig{
1171 Mode: genai.FunctionCallingConfigModeNone,
1172 },
1173 }
1174 default:
1175 googleToolChoice = &genai.ToolConfig{
1176 FunctionCallingConfig: &genai.FunctionCallingConfig{
1177 Mode: genai.FunctionCallingConfigModeAny,
1178 AllowedFunctionNames: []string{
1179 string(*toolChoice),
1180 },
1181 },
1182 }
1183 }
1184 return googleTools, googleToolChoice, warnings
1185}
1186
1187func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
1188 properties := make(map[string]*genai.Schema)
1189
1190 for name, param := range parameters {
1191 properties[name] = convertToSchema(param)
1192 }
1193
1194 return properties
1195}
1196
1197func convertToSchema(param any) *genai.Schema {
1198 schema := &genai.Schema{Type: genai.TypeString}
1199
1200 paramMap, ok := param.(map[string]any)
1201 if !ok {
1202 return schema
1203 }
1204
1205 if desc, ok := paramMap["description"].(string); ok {
1206 schema.Description = desc
1207 }
1208
1209 typeVal, hasType := paramMap["type"]
1210 if !hasType {
1211 return schema
1212 }
1213
1214 typeStr, ok := typeVal.(string)
1215 if !ok {
1216 return schema
1217 }
1218
1219 schema.Type = mapJSONTypeToGoogle(typeStr)
1220
1221 switch typeStr {
1222 case "array":
1223 schema.Items = processArrayItems(paramMap)
1224 case "object":
1225 if props, ok := paramMap["properties"].(map[string]any); ok {
1226 schema.Properties = convertSchemaProperties(props)
1227 }
1228 }
1229
1230 return schema
1231}
1232
1233func processArrayItems(paramMap map[string]any) *genai.Schema {
1234 items, ok := paramMap["items"].(map[string]any)
1235 if !ok {
1236 return nil
1237 }
1238
1239 return convertToSchema(items)
1240}
1241
1242func mapJSONTypeToGoogle(jsonType string) genai.Type {
1243 switch jsonType {
1244 case "string":
1245 return genai.TypeString
1246 case "number":
1247 return genai.TypeNumber
1248 case "integer":
1249 return genai.TypeInteger
1250 case "boolean":
1251 return genai.TypeBoolean
1252 case "array":
1253 return genai.TypeArray
1254 case "object":
1255 return genai.TypeObject
1256 default:
1257 return genai.TypeString // Default to string for unknown types
1258 }
1259}
1260
1261func (g languageModel) mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) {
1262 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
1263 return nil, errors.New("no response from model")
1264 }
1265
1266 var (
1267 content []fantasy.Content
1268 finishReason fantasy.FinishReason
1269 hasToolCalls bool
1270 candidate = response.Candidates[0]
1271 )
1272
1273 for _, part := range candidate.Content.Parts {
1274 switch {
1275 case part.Text != "":
1276 if part.Thought {
1277 reasoningContent := fantasy.ReasoningContent{Text: part.Text}
1278 if part.ThoughtSignature != nil {
1279 metadata := &ReasoningMetadata{
1280 Signature: string(part.ThoughtSignature),
1281 }
1282 reasoningContent.ProviderMetadata = fantasy.ProviderMetadata{
1283 Name: metadata,
1284 }
1285 }
1286 content = append(content, reasoningContent)
1287 } else {
1288 foundReasoning := false
1289 if part.ThoughtSignature != nil {
1290 metadata := &ReasoningMetadata{
1291 Signature: string(part.ThoughtSignature),
1292 }
1293 // find the last reasoning content and add the signature
1294 for i := len(content) - 1; i >= 0; i-- {
1295 c := content[i]
1296 if c.GetType() == fantasy.ContentTypeReasoning {
1297 reasoningContent, ok := fantasy.AsContentType[fantasy.ReasoningContent](c)
1298 if !ok {
1299 continue
1300 }
1301 reasoningContent.ProviderMetadata = fantasy.ProviderMetadata{
1302 Name: metadata,
1303 }
1304 content[i] = reasoningContent
1305 foundReasoning = true
1306 break
1307 }
1308 }
1309 if !foundReasoning {
1310 content = append(content, fantasy.ReasoningContent{
1311 ProviderMetadata: fantasy.ProviderMetadata{
1312 Name: metadata,
1313 },
1314 })
1315 }
1316 }
1317 content = append(content, fantasy.TextContent{Text: part.Text})
1318 }
1319 case part.FunctionCall != nil:
1320 input, err := json.Marshal(part.FunctionCall.Args)
1321 if err != nil {
1322 return nil, err
1323 }
1324 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
1325 foundReasoning := false
1326 if part.ThoughtSignature != nil {
1327 metadata := &ReasoningMetadata{
1328 Signature: string(part.ThoughtSignature),
1329 ToolID: toolCallID,
1330 }
1331 // find the last reasoning content and add the signature
1332 for i := len(content) - 1; i >= 0; i-- {
1333 c := content[i]
1334 if c.GetType() == fantasy.ContentTypeReasoning {
1335 reasoningContent, ok := fantasy.AsContentType[fantasy.ReasoningContent](c)
1336 if !ok {
1337 continue
1338 }
1339 reasoningContent.ProviderMetadata = fantasy.ProviderMetadata{
1340 Name: metadata,
1341 }
1342 content[i] = reasoningContent
1343 foundReasoning = true
1344 break
1345 }
1346 }
1347 if !foundReasoning {
1348 content = append(content, fantasy.ReasoningContent{
1349 ProviderMetadata: fantasy.ProviderMetadata{
1350 Name: metadata,
1351 },
1352 })
1353 }
1354 }
1355 content = append(content, fantasy.ToolCallContent{
1356 ToolCallID: toolCallID,
1357 ToolName: part.FunctionCall.Name,
1358 Input: string(input),
1359 ProviderExecuted: false,
1360 })
1361 hasToolCalls = true
1362 default:
1363 // Silently skip unknown part types instead of erroring
1364 // This allows for forward compatibility with new part types
1365 }
1366 }
1367
1368 if hasToolCalls {
1369 finishReason = fantasy.FinishReasonToolCalls
1370 } else {
1371 finishReason = mapFinishReason(candidate.FinishReason)
1372 }
1373
1374 return &fantasy.Response{
1375 Content: content,
1376 Usage: mapUsage(response.UsageMetadata),
1377 FinishReason: finishReason,
1378 Warnings: warnings,
1379 }, nil
1380}
1381
1382// GetReasoningMetadata extracts reasoning metadata from provider options for google models.
1383func GetReasoningMetadata(providerOptions fantasy.ProviderOptions) *ReasoningMetadata {
1384 if googleOptions, ok := providerOptions[Name]; ok {
1385 if reasoning, ok := googleOptions.(*ReasoningMetadata); ok {
1386 return reasoning
1387 }
1388 }
1389 return nil
1390}
1391
1392func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason {
1393 switch reason {
1394 case genai.FinishReasonStop:
1395 return fantasy.FinishReasonStop
1396 case genai.FinishReasonMaxTokens:
1397 return fantasy.FinishReasonLength
1398 case genai.FinishReasonSafety,
1399 genai.FinishReasonBlocklist,
1400 genai.FinishReasonProhibitedContent,
1401 genai.FinishReasonSPII,
1402 genai.FinishReasonImageSafety:
1403 return fantasy.FinishReasonContentFilter
1404 case genai.FinishReasonRecitation,
1405 genai.FinishReasonLanguage,
1406 genai.FinishReasonMalformedFunctionCall:
1407 return fantasy.FinishReasonError
1408 case genai.FinishReasonOther:
1409 return fantasy.FinishReasonOther
1410 default:
1411 return fantasy.FinishReasonUnknown
1412 }
1413}
1414
1415func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage {
1416 return fantasy.Usage{
1417 InputTokens: int64(usage.PromptTokenCount),
1418 OutputTokens: int64(usage.CandidatesTokenCount),
1419 TotalTokens: int64(usage.TotalTokenCount),
1420 ReasoningTokens: int64(usage.ThoughtsTokenCount),
1421 CacheCreationTokens: 0,
1422 CacheReadTokens: int64(usage.CachedContentTokenCount),
1423 }
1424}