1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "net/http"
11 "reflect"
12 "strings"
13
14 "charm.land/fantasy"
15 "charm.land/fantasy/object"
16 "charm.land/fantasy/providers/anthropic"
17 "charm.land/fantasy/providers/internal/httpheaders"
18 "charm.land/fantasy/schema"
19 "cloud.google.com/go/auth"
20 "github.com/charmbracelet/x/exp/slice"
21 "github.com/google/uuid"
22 "google.golang.org/genai"
23)
24
// Name is the default name of the Google provider. It is also the key under
// which Google-specific provider options and provider metadata are stored.
const Name = "google"
27
28type provider struct {
29 options options
30}
31
// ToolCallIDFunc is the signature of a tool call ID generator: it takes no
// arguments and returns the ID as a string.
type ToolCallIDFunc = func() string
34
35type options struct {
36 apiKey string
37 name string
38 baseURL string
39 headers map[string]string
40 userAgent string
41 client *http.Client
42 backend genai.Backend
43 project string
44 location string
45 skipAuth bool
46 toolCallIDFunc ToolCallIDFunc
47 objectMode fantasy.ObjectMode
48}
49
50// Option defines a function that configures Google provider options.
51type Option = func(*options)
52
53// New creates a new Google provider with the given options.
54func New(opts ...Option) (fantasy.Provider, error) {
55 options := options{
56 headers: map[string]string{},
57 toolCallIDFunc: func() string {
58 return uuid.NewString()
59 },
60 }
61 for _, o := range opts {
62 o(&options)
63 }
64
65 options.name = cmp.Or(options.name, Name)
66
67 return &provider{
68 options: options,
69 }, nil
70}
71
72// WithBaseURL sets the base URL for the Google provider.
73func WithBaseURL(baseURL string) Option {
74 return func(o *options) {
75 o.baseURL = baseURL
76 }
77}
78
79// WithGeminiAPIKey sets the Gemini API key for the Google provider.
80func WithGeminiAPIKey(apiKey string) Option {
81 return func(o *options) {
82 o.backend = genai.BackendGeminiAPI
83 o.apiKey = apiKey
84 o.project = ""
85 o.location = ""
86 }
87}
88
89// WithVertex configures the Google provider to use Vertex AI.
90func WithVertex(project, location string) Option {
91 if project == "" || location == "" {
92 panic("project and location must be provided")
93 }
94 return func(o *options) {
95 o.backend = genai.BackendVertexAI
96 o.apiKey = ""
97 o.project = project
98 o.location = location
99 }
100}
101
102// WithSkipAuth configures whether to skip authentication for the Google provider.
103func WithSkipAuth(skipAuth bool) Option {
104 return func(o *options) {
105 o.skipAuth = skipAuth
106 }
107}
108
109// WithName sets the name for the Google provider.
110func WithName(name string) Option {
111 return func(o *options) {
112 o.name = name
113 }
114}
115
116// WithHeaders sets the headers for the Google provider.
117func WithHeaders(headers map[string]string) Option {
118 return func(o *options) {
119 maps.Copy(o.headers, headers)
120 }
121}
122
123// WithHTTPClient sets the HTTP client for the Google provider.
124func WithHTTPClient(client *http.Client) Option {
125 return func(o *options) {
126 o.client = client
127 }
128}
129
130// WithToolCallIDFunc sets the function that generates a tool call ID.
131func WithToolCallIDFunc(f ToolCallIDFunc) Option {
132 return func(o *options) {
133 o.toolCallIDFunc = f
134 }
135}
136
137// WithUserAgent sets an explicit User-Agent header, overriding the default and any
138// value set via WithHeaders.
139func WithUserAgent(ua string) Option {
140 return func(o *options) {
141 o.userAgent = ua
142 }
143}
144
145// WithObjectMode sets the object generation mode for the Google provider.
146func WithObjectMode(om fantasy.ObjectMode) Option {
147 return func(o *options) {
148 o.objectMode = om
149 }
150}
151
152func (*provider) Name() string {
153 return Name
154}
155
156type languageModel struct {
157 provider string
158 modelID string
159 client *genai.Client
160 providerOptions options
161 objectMode fantasy.ObjectMode
162}
163
164// LanguageModel implements fantasy.Provider.
165func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.LanguageModel, error) {
166 if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
167 anthropicOpts := []anthropic.Option{
168 anthropic.WithVertex(a.options.project, a.options.location),
169 anthropic.WithHTTPClient(a.options.client),
170 anthropic.WithSkipAuth(a.options.skipAuth),
171 }
172 if a.options.userAgent != "" {
173 anthropicOpts = append(anthropicOpts, anthropic.WithUserAgent(a.options.userAgent))
174 }
175 p, err := anthropic.New(anthropicOpts...)
176 if err != nil {
177 return nil, err
178 }
179 return p.LanguageModel(ctx, modelID)
180 }
181
182 cc := &genai.ClientConfig{
183 HTTPClient: wrapHTTPClient(a.options.client),
184 Backend: a.options.backend,
185 APIKey: a.options.apiKey,
186 Project: a.options.project,
187 Location: a.options.location,
188 }
189 if a.options.skipAuth {
190 cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
191 } else if cc.Backend == genai.BackendVertexAI {
192 if err := cc.UseDefaultCredentials(); err != nil {
193 return nil, err
194 }
195 }
196
197 defaultUA := httpheaders.DefaultUserAgent(fantasy.Version)
198 resolved := httpheaders.ResolveHeaders(a.options.headers, a.options.userAgent, defaultUA)
199
200 headers := http.Header{}
201 for k, v := range resolved {
202 headers.Set(k, v)
203 }
204 cc.HTTPOptions = genai.HTTPOptions{
205 BaseURL: a.options.baseURL,
206 Headers: headers,
207 }
208 client, err := genai.NewClient(ctx, cc)
209 if err != nil {
210 return nil, err
211 }
212
213 objectMode := a.options.objectMode
214 if objectMode == "" {
215 objectMode = fantasy.ObjectModeAuto
216 }
217
218 return &languageModel{
219 modelID: modelID,
220 provider: a.options.name,
221 providerOptions: a.options,
222 client: client,
223 objectMode: objectMode,
224 }, nil
225}
226
227func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) {
228 config := &genai.GenerateContentConfig{}
229
230 providerOptions := &ProviderOptions{}
231 if v, ok := call.ProviderOptions[Name]; ok {
232 providerOptions, ok = v.(*ProviderOptions)
233 if !ok {
234 return nil, nil, nil, &fantasy.Error{Title: "invalid argument", Message: "google provider options should be *google.ProviderOptions"}
235 }
236 }
237
238 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
239
240 if providerOptions.ThinkingConfig != nil {
241 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
242 *providerOptions.ThinkingConfig.IncludeThoughts &&
243 strings.HasPrefix(g.provider, "google.vertex.") {
244 warnings = append(warnings, fantasy.CallWarning{
245 Type: fantasy.CallWarningTypeOther,
246 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
247 "and might not be supported or could behave unexpectedly with the current Google provider " +
248 fmt.Sprintf("(%s)", g.provider),
249 })
250 }
251
252 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
253 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
254 warnings = append(warnings, fantasy.CallWarning{
255 Type: fantasy.CallWarningTypeOther,
256 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
257 })
258 providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128))
259 }
260
261 if providerOptions.ThinkingConfig.ThinkingLevel != nil &&
262 providerOptions.ThinkingConfig.ThinkingBudget != nil {
263 return nil, nil, nil, &fantasy.Error{
264 Title: "invalid argument",
265 Message: "thinking_level and thinking_budget are mutually exclusive",
266 }
267 }
268 }
269
270 isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")
271
272 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
273 if len(content) > 0 && content[0].Role == genai.RoleUser {
274 systemParts := []string{}
275 for _, sp := range systemInstructions.Parts {
276 systemParts = append(systemParts, sp.Text)
277 }
278 systemMsg := strings.Join(systemParts, "\n")
279 content[0].Parts = append([]*genai.Part{
280 {
281 Text: systemMsg + "\n\n",
282 },
283 }, content[0].Parts...)
284 systemInstructions = nil
285 }
286 }
287
288 config.SystemInstruction = systemInstructions
289
290 if call.MaxOutputTokens != nil {
291 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
292 }
293
294 if call.Temperature != nil {
295 tmp := float32(*call.Temperature)
296 config.Temperature = &tmp
297 }
298 if call.TopK != nil {
299 tmp := float32(*call.TopK)
300 config.TopK = &tmp
301 }
302 if call.TopP != nil {
303 tmp := float32(*call.TopP)
304 config.TopP = &tmp
305 }
306 if call.FrequencyPenalty != nil {
307 tmp := float32(*call.FrequencyPenalty)
308 config.FrequencyPenalty = &tmp
309 }
310 if call.PresencePenalty != nil {
311 tmp := float32(*call.PresencePenalty)
312 config.PresencePenalty = &tmp
313 }
314
315 if providerOptions.ThinkingConfig != nil {
316 config.ThinkingConfig = &genai.ThinkingConfig{}
317 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
318 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
319 }
320 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
321 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
322 config.ThinkingConfig.ThinkingBudget = &tmp
323 }
324 if providerOptions.ThinkingConfig.ThinkingLevel != nil {
325 config.ThinkingConfig.ThinkingLevel = genai.ThinkingLevel(*providerOptions.ThinkingConfig.ThinkingLevel)
326 }
327 }
328 for _, safetySetting := range providerOptions.SafetySettings {
329 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
330 Category: genai.HarmCategory(safetySetting.Category),
331 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
332 })
333 }
334 if providerOptions.CachedContent != "" {
335 config.CachedContent = providerOptions.CachedContent
336 }
337
338 if len(call.Tools) > 0 {
339 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
340 config.ToolConfig = toolChoice
341 config.Tools = append(config.Tools, &genai.Tool{
342 FunctionDeclarations: tools,
343 })
344 warnings = append(warnings, toolWarnings...)
345 }
346
347 return config, content, warnings, nil
348}
349
350func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam
351 var systemInstructions *genai.Content
352 var content []*genai.Content
353 var warnings []fantasy.CallWarning
354
355 finishedSystemBlock := false
356 for _, msg := range prompt {
357 switch msg.Role {
358 case fantasy.MessageRoleSystem:
359 if finishedSystemBlock {
360 // skip multiple system messages that are separated by user/assistant messages
361 // TODO: see if we need to send error here?
362 continue
363 }
364 finishedSystemBlock = true
365
366 var systemMessages []string
367 for _, part := range msg.Content {
368 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
369 if !ok || text.Text == "" {
370 continue
371 }
372 systemMessages = append(systemMessages, text.Text)
373 }
374 if len(systemMessages) > 0 {
375 systemInstructions = &genai.Content{
376 Parts: []*genai.Part{
377 {
378 Text: strings.Join(systemMessages, "\n"),
379 },
380 },
381 }
382 }
383 case fantasy.MessageRoleUser:
384 var parts []*genai.Part
385 for _, part := range msg.Content {
386 switch part.GetType() {
387 case fantasy.ContentTypeText:
388 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
389 if !ok || text.Text == "" {
390 continue
391 }
392 parts = append(parts, &genai.Part{
393 Text: text.Text,
394 })
395 case fantasy.ContentTypeFile:
396 file, ok := fantasy.AsMessagePart[fantasy.FilePart](part)
397 if !ok {
398 continue
399 }
400 parts = append(parts, &genai.Part{
401 InlineData: &genai.Blob{
402 Data: file.Data,
403 MIMEType: file.MediaType,
404 },
405 })
406 }
407 }
408 if len(parts) > 0 {
409 content = append(content, &genai.Content{
410 Role: genai.RoleUser,
411 Parts: parts,
412 })
413 }
414 case fantasy.MessageRoleAssistant:
415 var parts []*genai.Part
416 var currentReasoningMetadata *ReasoningMetadata
417 for _, part := range msg.Content {
418 switch part.GetType() {
419 case fantasy.ContentTypeReasoning:
420 reasoning, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](part)
421 if !ok {
422 continue
423 }
424
425 metadata, ok := reasoning.ProviderOptions[Name]
426 if !ok {
427 continue
428 }
429 reasoningMetadata, ok := metadata.(*ReasoningMetadata)
430 if !ok {
431 continue
432 }
433 currentReasoningMetadata = reasoningMetadata
434 case fantasy.ContentTypeText:
435 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
436 if !ok || text.Text == "" {
437 continue
438 }
439 geminiPart := &genai.Part{
440 Text: text.Text,
441 }
442 if currentReasoningMetadata != nil {
443 geminiPart.ThoughtSignature = []byte(currentReasoningMetadata.Signature)
444 currentReasoningMetadata = nil
445 }
446 parts = append(parts, geminiPart)
447 case fantasy.ContentTypeToolCall:
448 toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
449 if !ok {
450 continue
451 }
452
453 var result map[string]any
454 err := json.Unmarshal([]byte(toolCall.Input), &result)
455 if err != nil {
456 continue
457 }
458 geminiPart := &genai.Part{
459 FunctionCall: &genai.FunctionCall{
460 ID: toolCall.ToolCallID,
461 Name: toolCall.ToolName,
462 Args: result,
463 },
464 }
465 if currentReasoningMetadata != nil {
466 geminiPart.ThoughtSignature = []byte(currentReasoningMetadata.Signature)
467 currentReasoningMetadata = nil
468 }
469 parts = append(parts, geminiPart)
470 }
471 }
472 if len(parts) > 0 {
473 content = append(content, &genai.Content{
474 Role: genai.RoleModel,
475 Parts: parts,
476 })
477 }
478 case fantasy.MessageRoleTool:
479 var parts []*genai.Part
480 for _, part := range msg.Content {
481 switch part.GetType() {
482 case fantasy.ContentTypeToolResult:
483 result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
484 if !ok {
485 continue
486 }
487 var toolCall fantasy.ToolCallPart
488 for _, m := range prompt {
489 if m.Role == fantasy.MessageRoleAssistant {
490 for _, content := range m.Content {
491 tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content)
492 if !ok {
493 continue
494 }
495 if tc.ToolCallID == result.ToolCallID {
496 toolCall = tc
497 break
498 }
499 }
500 }
501 }
502 switch result.Output.GetType() {
503 case fantasy.ToolResultContentTypeText:
504 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output)
505 if !ok {
506 continue
507 }
508 response := map[string]any{"result": content.Text}
509 parts = append(parts, &genai.Part{
510 FunctionResponse: &genai.FunctionResponse{
511 ID: result.ToolCallID,
512 Response: response,
513 Name: toolCall.ToolName,
514 },
515 })
516
517 case fantasy.ToolResultContentTypeError:
518 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
519 if !ok {
520 continue
521 }
522 response := map[string]any{"result": content.Error.Error()}
523 parts = append(parts, &genai.Part{
524 FunctionResponse: &genai.FunctionResponse{
525 ID: result.ToolCallID,
526 Response: response,
527 Name: toolCall.ToolName,
528 },
529 })
530 }
531 }
532 }
533 if len(parts) > 0 {
534 content = append(content, &genai.Content{
535 Role: genai.RoleUser,
536 Parts: parts,
537 })
538 }
539 default:
540 panic("unsupported message role: " + msg.Role)
541 }
542 }
543 return systemInstructions, content, warnings
544}
545
546// Generate implements fantasy.LanguageModel.
547func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
548 ctx = withCallUA(ctx, call)
549 config, contents, warnings, err := g.prepareParams(call)
550 if err != nil {
551 return nil, err
552 }
553
554 lastMessage, history, ok := slice.Pop(contents)
555 if !ok {
556 return nil, errors.New("no messages to send")
557 }
558
559 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
560 if err != nil {
561 return nil, err
562 }
563
564 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
565 if err != nil {
566 return nil, toProviderErr(err)
567 }
568
569 return g.mapResponse(response, warnings)
570}
571
572// Model implements fantasy.LanguageModel.
573func (g *languageModel) Model() string {
574 return g.modelID
575}
576
577// Provider implements fantasy.LanguageModel.
578func (g *languageModel) Provider() string {
579 return g.provider
580}
581
582// Stream implements fantasy.LanguageModel.
583func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
584 ctx = withCallUA(ctx, call)
585 config, contents, warnings, err := g.prepareParams(call)
586 if err != nil {
587 return nil, err
588 }
589
590 lastMessage, history, ok := slice.Pop(contents)
591 if !ok {
592 return nil, errors.New("no messages to send")
593 }
594
595 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
596 if err != nil {
597 return nil, err
598 }
599
600 return func(yield func(fantasy.StreamPart) bool) {
601 if len(warnings) > 0 {
602 if !yield(fantasy.StreamPart{
603 Type: fantasy.StreamPartTypeWarnings,
604 Warnings: warnings,
605 }) {
606 return
607 }
608 }
609
610 var currentContent string
611 var toolCalls []fantasy.ToolCallContent
612 var isActiveText bool
613 var isActiveReasoning bool
614 var blockCounter int
615 var currentTextBlockID string
616 var currentReasoningBlockID string
617 var usage *fantasy.Usage
618 var lastFinishReason fantasy.FinishReason
619
620 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
621 if err != nil {
622 yield(fantasy.StreamPart{
623 Type: fantasy.StreamPartTypeError,
624 Error: toProviderErr(err),
625 })
626 return
627 }
628
629 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
630 for _, part := range resp.Candidates[0].Content.Parts {
631 switch {
632 case part.Text != "":
633 delta := part.Text
634 if delta != "" {
635 // Check if this is a reasoning/thought part
636 if part.Thought {
637 // End any active text block before starting reasoning
638 if isActiveText {
639 isActiveText = false
640 if !yield(fantasy.StreamPart{
641 Type: fantasy.StreamPartTypeTextEnd,
642 ID: currentTextBlockID,
643 }) {
644 return
645 }
646 }
647
648 // Start new reasoning block if not already active
649 if !isActiveReasoning {
650 isActiveReasoning = true
651 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
652 blockCounter++
653 if !yield(fantasy.StreamPart{
654 Type: fantasy.StreamPartTypeReasoningStart,
655 ID: currentReasoningBlockID,
656 }) {
657 return
658 }
659 }
660
661 if !yield(fantasy.StreamPart{
662 Type: fantasy.StreamPartTypeReasoningDelta,
663 ID: currentReasoningBlockID,
664 Delta: delta,
665 }) {
666 return
667 }
668 } else {
669 // Start new text block if not already active
670 if !isActiveText {
671 isActiveText = true
672 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
673 blockCounter++
674 if !yield(fantasy.StreamPart{
675 Type: fantasy.StreamPartTypeTextStart,
676 ID: currentTextBlockID,
677 }) {
678 return
679 }
680 }
681 // End any active reasoning block before starting text
682 if isActiveReasoning {
683 isActiveReasoning = false
684 metadata := &ReasoningMetadata{
685 Signature: string(part.ThoughtSignature),
686 }
687 if !yield(fantasy.StreamPart{
688 Type: fantasy.StreamPartTypeReasoningEnd,
689 ID: currentReasoningBlockID,
690 ProviderMetadata: fantasy.ProviderMetadata{
691 Name: metadata,
692 },
693 }) {
694 return
695 }
696 } else if part.ThoughtSignature != nil {
697 metadata := &ReasoningMetadata{
698 Signature: string(part.ThoughtSignature),
699 }
700
701 if !yield(fantasy.StreamPart{
702 Type: fantasy.StreamPartTypeReasoningStart,
703 ID: currentReasoningBlockID,
704 }) {
705 return
706 }
707 if !yield(fantasy.StreamPart{
708 Type: fantasy.StreamPartTypeReasoningEnd,
709 ID: currentReasoningBlockID,
710 ProviderMetadata: fantasy.ProviderMetadata{
711 Name: metadata,
712 },
713 }) {
714 return
715 }
716 }
717
718 if !yield(fantasy.StreamPart{
719 Type: fantasy.StreamPartTypeTextDelta,
720 ID: currentTextBlockID,
721 Delta: delta,
722 }) {
723 return
724 }
725 currentContent += delta
726 }
727 }
728 case part.FunctionCall != nil:
729 // End any active text or reasoning blocks
730 if isActiveText {
731 isActiveText = false
732 if !yield(fantasy.StreamPart{
733 Type: fantasy.StreamPartTypeTextEnd,
734 ID: currentTextBlockID,
735 }) {
736 return
737 }
738 }
739 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
740 // End any active reasoning block before starting text
741 if isActiveReasoning {
742 isActiveReasoning = false
743 metadata := &ReasoningMetadata{
744 Signature: string(part.ThoughtSignature),
745 ToolID: toolCallID,
746 }
747 if !yield(fantasy.StreamPart{
748 Type: fantasy.StreamPartTypeReasoningEnd,
749 ID: currentReasoningBlockID,
750 ProviderMetadata: fantasy.ProviderMetadata{
751 Name: metadata,
752 },
753 }) {
754 return
755 }
756 } else if part.ThoughtSignature != nil {
757 metadata := &ReasoningMetadata{
758 Signature: string(part.ThoughtSignature),
759 ToolID: toolCallID,
760 }
761
762 if !yield(fantasy.StreamPart{
763 Type: fantasy.StreamPartTypeReasoningStart,
764 ID: currentReasoningBlockID,
765 }) {
766 return
767 }
768 if !yield(fantasy.StreamPart{
769 Type: fantasy.StreamPartTypeReasoningEnd,
770 ID: currentReasoningBlockID,
771 ProviderMetadata: fantasy.ProviderMetadata{
772 Name: metadata,
773 },
774 }) {
775 return
776 }
777 }
778 args, err := json.Marshal(part.FunctionCall.Args)
779 if err != nil {
780 yield(fantasy.StreamPart{
781 Type: fantasy.StreamPartTypeError,
782 Error: err,
783 })
784 return
785 }
786
787 if !yield(fantasy.StreamPart{
788 Type: fantasy.StreamPartTypeToolInputStart,
789 ID: toolCallID,
790 ToolCallName: part.FunctionCall.Name,
791 }) {
792 return
793 }
794
795 if !yield(fantasy.StreamPart{
796 Type: fantasy.StreamPartTypeToolInputDelta,
797 ID: toolCallID,
798 Delta: string(args),
799 }) {
800 return
801 }
802
803 if !yield(fantasy.StreamPart{
804 Type: fantasy.StreamPartTypeToolInputEnd,
805 ID: toolCallID,
806 }) {
807 return
808 }
809
810 if !yield(fantasy.StreamPart{
811 Type: fantasy.StreamPartTypeToolCall,
812 ID: toolCallID,
813 ToolCallName: part.FunctionCall.Name,
814 ToolCallInput: string(args),
815 ProviderExecuted: false,
816 }) {
817 return
818 }
819
820 toolCalls = append(toolCalls, fantasy.ToolCallContent{
821 ToolCallID: toolCallID,
822 ToolName: part.FunctionCall.Name,
823 Input: string(args),
824 ProviderExecuted: false,
825 })
826 }
827 }
828 }
829
830 // we need to make sure that there is actual tokendata
831 if resp.UsageMetadata != nil && resp.UsageMetadata.TotalTokenCount != 0 {
832 currentUsage := mapUsage(resp.UsageMetadata)
833 // if first usage chunk
834 if usage == nil {
835 usage = ¤tUsage
836 } else {
837 usage.OutputTokens += currentUsage.OutputTokens
838 usage.ReasoningTokens += currentUsage.ReasoningTokens
839 usage.CacheReadTokens += currentUsage.CacheReadTokens
840 }
841 }
842
843 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
844 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
845 }
846 }
847
848 // Close any open blocks before finishing
849 if isActiveText {
850 if !yield(fantasy.StreamPart{
851 Type: fantasy.StreamPartTypeTextEnd,
852 ID: currentTextBlockID,
853 }) {
854 return
855 }
856 }
857 if isActiveReasoning {
858 if !yield(fantasy.StreamPart{
859 Type: fantasy.StreamPartTypeReasoningEnd,
860 ID: currentReasoningBlockID,
861 }) {
862 return
863 }
864 }
865
866 finishReason := lastFinishReason
867 if len(toolCalls) > 0 {
868 finishReason = fantasy.FinishReasonToolCalls
869 } else if finishReason == "" {
870 finishReason = fantasy.FinishReasonStop
871 }
872
873 var finalUsage fantasy.Usage
874 if usage != nil {
875 finalUsage = *usage
876 }
877
878 yield(fantasy.StreamPart{
879 Type: fantasy.StreamPartTypeFinish,
880 Usage: finalUsage,
881 FinishReason: finishReason,
882 })
883 }, nil
884}
885
886// GenerateObject implements fantasy.LanguageModel.
887func (g *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
888 switch g.objectMode {
889 case fantasy.ObjectModeText:
890 return object.GenerateWithText(ctx, g, call)
891 case fantasy.ObjectModeTool:
892 return object.GenerateWithTool(ctx, g, call)
893 default:
894 return g.generateObjectWithJSONMode(ctx, call)
895 }
896}
897
898// StreamObject implements fantasy.LanguageModel.
899func (g *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
900 switch g.objectMode {
901 case fantasy.ObjectModeTool:
902 return object.StreamWithTool(ctx, g, call)
903 case fantasy.ObjectModeText:
904 return object.StreamWithText(ctx, g, call)
905 default:
906 return g.streamObjectWithJSONMode(ctx, call)
907 }
908}
909
910func (g *languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
911 ctx = withObjectCallUA(ctx, call)
912 // Convert our Schema to Google's JSON Schema format
913 jsonSchemaMap := schema.ToMap(call.Schema)
914
915 // Build request using prepareParams
916 fantasyCall := fantasy.Call{
917 Prompt: call.Prompt,
918 MaxOutputTokens: call.MaxOutputTokens,
919 Temperature: call.Temperature,
920 TopP: call.TopP,
921 TopK: call.TopK,
922 PresencePenalty: call.PresencePenalty,
923 FrequencyPenalty: call.FrequencyPenalty,
924 ProviderOptions: call.ProviderOptions,
925 }
926
927 config, contents, warnings, err := g.prepareParams(fantasyCall)
928 if err != nil {
929 return nil, err
930 }
931
932 // Set ResponseMIMEType and ResponseJsonSchema for structured output
933 config.ResponseMIMEType = "application/json"
934 config.ResponseJsonSchema = jsonSchemaMap
935
936 lastMessage, history, ok := slice.Pop(contents)
937 if !ok {
938 return nil, errors.New("no messages to send")
939 }
940
941 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
942 if err != nil {
943 return nil, err
944 }
945
946 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
947 if err != nil {
948 return nil, toProviderErr(err)
949 }
950
951 mappedResponse, err := g.mapResponse(response, warnings)
952 if err != nil {
953 return nil, err
954 }
955
956 jsonText := mappedResponse.Content.Text()
957 if jsonText == "" {
958 return nil, &fantasy.NoObjectGeneratedError{
959 RawText: "",
960 ParseError: fmt.Errorf("no text content in response"),
961 Usage: mappedResponse.Usage,
962 FinishReason: mappedResponse.FinishReason,
963 }
964 }
965
966 // Parse and validate
967 var obj any
968 if call.RepairText != nil {
969 obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
970 } else {
971 obj, err = schema.ParseAndValidate(jsonText, call.Schema)
972 }
973
974 if err != nil {
975 // Add usage info to error
976 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
977 nogErr.Usage = mappedResponse.Usage
978 nogErr.FinishReason = mappedResponse.FinishReason
979 }
980 return nil, err
981 }
982
983 return &fantasy.ObjectResponse{
984 Object: obj,
985 RawText: jsonText,
986 Usage: mappedResponse.Usage,
987 FinishReason: mappedResponse.FinishReason,
988 Warnings: warnings,
989 ProviderMetadata: mappedResponse.ProviderMetadata,
990 }, nil
991}
992
993func (g *languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
994 ctx = withObjectCallUA(ctx, call)
995 // Convert our Schema to Google's JSON Schema format
996 jsonSchemaMap := schema.ToMap(call.Schema)
997
998 // Build request using prepareParams
999 fantasyCall := fantasy.Call{
1000 Prompt: call.Prompt,
1001 MaxOutputTokens: call.MaxOutputTokens,
1002 Temperature: call.Temperature,
1003 TopP: call.TopP,
1004 TopK: call.TopK,
1005 PresencePenalty: call.PresencePenalty,
1006 FrequencyPenalty: call.FrequencyPenalty,
1007 ProviderOptions: call.ProviderOptions,
1008 }
1009
1010 config, contents, warnings, err := g.prepareParams(fantasyCall)
1011 if err != nil {
1012 return nil, err
1013 }
1014
1015 // Set ResponseMIMEType and ResponseJsonSchema for structured output
1016 config.ResponseMIMEType = "application/json"
1017 config.ResponseJsonSchema = jsonSchemaMap
1018
1019 lastMessage, history, ok := slice.Pop(contents)
1020 if !ok {
1021 return nil, errors.New("no messages to send")
1022 }
1023
1024 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
1025 if err != nil {
1026 return nil, err
1027 }
1028
1029 return func(yield func(fantasy.ObjectStreamPart) bool) {
1030 if len(warnings) > 0 {
1031 if !yield(fantasy.ObjectStreamPart{
1032 Type: fantasy.ObjectStreamPartTypeObject,
1033 Warnings: warnings,
1034 }) {
1035 return
1036 }
1037 }
1038
1039 var accumulated string
1040 var lastParsedObject any
1041 var usage *fantasy.Usage
1042 var lastFinishReason fantasy.FinishReason
1043 var streamErr error
1044
1045 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
1046 if err != nil {
1047 streamErr = toProviderErr(err)
1048 yield(fantasy.ObjectStreamPart{
1049 Type: fantasy.ObjectStreamPartTypeError,
1050 Error: streamErr,
1051 })
1052 return
1053 }
1054
1055 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
1056 for _, part := range resp.Candidates[0].Content.Parts {
1057 if part.Text != "" && !part.Thought {
1058 accumulated += part.Text
1059
1060 // Try to parse the accumulated text
1061 obj, state, parseErr := schema.ParsePartialJSON(accumulated)
1062
1063 // If we successfully parsed, validate and emit
1064 if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
1065 if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
1066 // Only emit if object is different from last
1067 if !reflect.DeepEqual(obj, lastParsedObject) {
1068 if !yield(fantasy.ObjectStreamPart{
1069 Type: fantasy.ObjectStreamPartTypeObject,
1070 Object: obj,
1071 }) {
1072 return
1073 }
1074 lastParsedObject = obj
1075 }
1076 }
1077 }
1078
1079 // If parsing failed and we have a repair function, try it
1080 if state == schema.ParseStateFailed && call.RepairText != nil {
1081 repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
1082 if repairErr == nil {
1083 obj2, state2, _ := schema.ParsePartialJSON(repairedText)
1084 if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
1085 schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
1086 if !reflect.DeepEqual(obj2, lastParsedObject) {
1087 if !yield(fantasy.ObjectStreamPart{
1088 Type: fantasy.ObjectStreamPartTypeObject,
1089 Object: obj2,
1090 }) {
1091 return
1092 }
1093 lastParsedObject = obj2
1094 }
1095 }
1096 }
1097 }
1098 }
1099 }
1100 }
1101
1102 // we need to make sure that there is actual tokendata
1103 if resp.UsageMetadata != nil && resp.UsageMetadata.TotalTokenCount != 0 {
1104 currentUsage := mapUsage(resp.UsageMetadata)
1105 if usage == nil {
1106 usage = ¤tUsage
1107 } else {
1108 usage.OutputTokens += currentUsage.OutputTokens
1109 usage.ReasoningTokens += currentUsage.ReasoningTokens
1110 usage.CacheReadTokens += currentUsage.CacheReadTokens
1111 }
1112 }
1113
1114 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
1115 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
1116 }
1117 }
1118
1119 // Final validation and emit
1120 if streamErr == nil && lastParsedObject != nil {
1121 finishReason := cmp.Or(lastFinishReason, fantasy.FinishReasonStop)
1122
1123 var finalUsage fantasy.Usage
1124 if usage != nil {
1125 finalUsage = *usage
1126 }
1127
1128 yield(fantasy.ObjectStreamPart{
1129 Type: fantasy.ObjectStreamPartTypeFinish,
1130 Usage: finalUsage,
1131 FinishReason: finishReason,
1132 })
1133 } else if streamErr == nil && lastParsedObject == nil {
1134 // No object was generated
1135 var finalUsage fantasy.Usage
1136 if usage != nil {
1137 finalUsage = *usage
1138 }
1139 yield(fantasy.ObjectStreamPart{
1140 Type: fantasy.ObjectStreamPartTypeError,
1141 Error: &fantasy.NoObjectGeneratedError{
1142 RawText: accumulated,
1143 ParseError: fmt.Errorf("no valid object generated in stream"),
1144 Usage: finalUsage,
1145 FinishReason: lastFinishReason,
1146 },
1147 })
1148 }
1149 }, nil
1150}
1151
1152func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) {
1153 for _, tool := range tools {
1154 if tool.GetType() == fantasy.ToolTypeFunction {
1155 ft, ok := tool.(fantasy.FunctionTool)
1156 if !ok {
1157 continue
1158 }
1159
1160 var required []string
1161 var properties map[string]any
1162 if props, ok := ft.InputSchema["properties"]; ok {
1163 properties, _ = props.(map[string]any)
1164 }
1165 if req, ok := ft.InputSchema["required"]; ok {
1166 if reqArr, ok := req.([]string); ok {
1167 required = reqArr
1168 }
1169 }
1170 declaration := &genai.FunctionDeclaration{
1171 Name: ft.Name,
1172 Description: ft.Description,
1173 Parameters: &genai.Schema{
1174 Type: genai.TypeObject,
1175 Properties: convertSchemaProperties(properties),
1176 Required: required,
1177 },
1178 }
1179 googleTools = append(googleTools, declaration)
1180 continue
1181 }
1182 // TODO: handle provider tool calls
1183 warnings = append(warnings, fantasy.CallWarning{
1184 Type: fantasy.CallWarningTypeUnsupportedTool,
1185 Tool: tool,
1186 Message: "tool is not supported",
1187 })
1188 }
1189 if toolChoice == nil {
1190 return googleTools, googleToolChoice, warnings
1191 }
1192 switch *toolChoice {
1193 case fantasy.ToolChoiceAuto:
1194 googleToolChoice = &genai.ToolConfig{
1195 FunctionCallingConfig: &genai.FunctionCallingConfig{
1196 Mode: genai.FunctionCallingConfigModeAuto,
1197 },
1198 }
1199 case fantasy.ToolChoiceRequired:
1200 googleToolChoice = &genai.ToolConfig{
1201 FunctionCallingConfig: &genai.FunctionCallingConfig{
1202 Mode: genai.FunctionCallingConfigModeAny,
1203 },
1204 }
1205 case fantasy.ToolChoiceNone:
1206 googleToolChoice = &genai.ToolConfig{
1207 FunctionCallingConfig: &genai.FunctionCallingConfig{
1208 Mode: genai.FunctionCallingConfigModeNone,
1209 },
1210 }
1211 default:
1212 googleToolChoice = &genai.ToolConfig{
1213 FunctionCallingConfig: &genai.FunctionCallingConfig{
1214 Mode: genai.FunctionCallingConfigModeAny,
1215 AllowedFunctionNames: []string{
1216 string(*toolChoice),
1217 },
1218 },
1219 }
1220 }
1221 return googleTools, googleToolChoice, warnings
1222}
1223
1224func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
1225 properties := make(map[string]*genai.Schema)
1226
1227 for name, param := range parameters {
1228 properties[name] = convertToSchema(param)
1229 }
1230
1231 return properties
1232}
1233
1234func convertToSchema(param any) *genai.Schema {
1235 schema := &genai.Schema{Type: genai.TypeString}
1236
1237 paramMap, ok := param.(map[string]any)
1238 if !ok {
1239 return schema
1240 }
1241
1242 if desc, ok := paramMap["description"].(string); ok {
1243 schema.Description = desc
1244 }
1245
1246 typeVal, hasType := paramMap["type"]
1247 if !hasType {
1248 return schema
1249 }
1250
1251 typeStr, ok := typeVal.(string)
1252 if !ok {
1253 return schema
1254 }
1255
1256 schema.Type = mapJSONTypeToGoogle(typeStr)
1257
1258 switch typeStr {
1259 case "array":
1260 schema.Items = processArrayItems(paramMap)
1261 case "object":
1262 if props, ok := paramMap["properties"].(map[string]any); ok {
1263 schema.Properties = convertSchemaProperties(props)
1264 }
1265 }
1266
1267 return schema
1268}
1269
1270func processArrayItems(paramMap map[string]any) *genai.Schema {
1271 items, ok := paramMap["items"].(map[string]any)
1272 if !ok {
1273 return nil
1274 }
1275
1276 return convertToSchema(items)
1277}
1278
1279func mapJSONTypeToGoogle(jsonType string) genai.Type {
1280 switch jsonType {
1281 case "string":
1282 return genai.TypeString
1283 case "number":
1284 return genai.TypeNumber
1285 case "integer":
1286 return genai.TypeInteger
1287 case "boolean":
1288 return genai.TypeBoolean
1289 case "array":
1290 return genai.TypeArray
1291 case "object":
1292 return genai.TypeObject
1293 default:
1294 return genai.TypeString // Default to string for unknown types
1295 }
1296}
1297
1298func (g languageModel) mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) {
1299 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
1300 return nil, errors.New("no response from model")
1301 }
1302
1303 var (
1304 content []fantasy.Content
1305 finishReason fantasy.FinishReason
1306 hasToolCalls bool
1307 candidate = response.Candidates[0]
1308 )
1309
1310 for _, part := range candidate.Content.Parts {
1311 switch {
1312 case part.Text != "":
1313 if part.Thought {
1314 reasoningContent := fantasy.ReasoningContent{Text: part.Text}
1315 if part.ThoughtSignature != nil {
1316 metadata := &ReasoningMetadata{
1317 Signature: string(part.ThoughtSignature),
1318 }
1319 reasoningContent.ProviderMetadata = fantasy.ProviderMetadata{
1320 Name: metadata,
1321 }
1322 }
1323 content = append(content, reasoningContent)
1324 } else {
1325 foundReasoning := false
1326 if part.ThoughtSignature != nil {
1327 metadata := &ReasoningMetadata{
1328 Signature: string(part.ThoughtSignature),
1329 }
1330 // find the last reasoning content and add the signature
1331 for i := len(content) - 1; i >= 0; i-- {
1332 c := content[i]
1333 if c.GetType() == fantasy.ContentTypeReasoning {
1334 reasoningContent, ok := fantasy.AsContentType[fantasy.ReasoningContent](c)
1335 if !ok {
1336 continue
1337 }
1338 reasoningContent.ProviderMetadata = fantasy.ProviderMetadata{
1339 Name: metadata,
1340 }
1341 content[i] = reasoningContent
1342 foundReasoning = true
1343 break
1344 }
1345 }
1346 if !foundReasoning {
1347 content = append(content, fantasy.ReasoningContent{
1348 ProviderMetadata: fantasy.ProviderMetadata{
1349 Name: metadata,
1350 },
1351 })
1352 }
1353 }
1354 content = append(content, fantasy.TextContent{Text: part.Text})
1355 }
1356 case part.FunctionCall != nil:
1357 input, err := json.Marshal(part.FunctionCall.Args)
1358 if err != nil {
1359 return nil, err
1360 }
1361 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
1362 foundReasoning := false
1363 if part.ThoughtSignature != nil {
1364 metadata := &ReasoningMetadata{
1365 Signature: string(part.ThoughtSignature),
1366 ToolID: toolCallID,
1367 }
1368 // find the last reasoning content and add the signature
1369 for i := len(content) - 1; i >= 0; i-- {
1370 c := content[i]
1371 if c.GetType() == fantasy.ContentTypeReasoning {
1372 reasoningContent, ok := fantasy.AsContentType[fantasy.ReasoningContent](c)
1373 if !ok {
1374 continue
1375 }
1376 reasoningContent.ProviderMetadata = fantasy.ProviderMetadata{
1377 Name: metadata,
1378 }
1379 content[i] = reasoningContent
1380 foundReasoning = true
1381 break
1382 }
1383 }
1384 if !foundReasoning {
1385 content = append(content, fantasy.ReasoningContent{
1386 ProviderMetadata: fantasy.ProviderMetadata{
1387 Name: metadata,
1388 },
1389 })
1390 }
1391 }
1392 content = append(content, fantasy.ToolCallContent{
1393 ToolCallID: toolCallID,
1394 ToolName: part.FunctionCall.Name,
1395 Input: string(input),
1396 ProviderExecuted: false,
1397 })
1398 hasToolCalls = true
1399 default:
1400 // Silently skip unknown part types instead of erroring
1401 // This allows for forward compatibility with new part types
1402 }
1403 }
1404
1405 if hasToolCalls {
1406 finishReason = fantasy.FinishReasonToolCalls
1407 } else {
1408 finishReason = mapFinishReason(candidate.FinishReason)
1409 }
1410
1411 return &fantasy.Response{
1412 Content: content,
1413 Usage: mapUsage(response.UsageMetadata),
1414 FinishReason: finishReason,
1415 Warnings: warnings,
1416 }, nil
1417}
1418
1419// GetReasoningMetadata extracts reasoning metadata from provider options for google models.
1420func GetReasoningMetadata(providerOptions fantasy.ProviderOptions) *ReasoningMetadata {
1421 if googleOptions, ok := providerOptions[Name]; ok {
1422 if reasoning, ok := googleOptions.(*ReasoningMetadata); ok {
1423 return reasoning
1424 }
1425 }
1426 return nil
1427}
1428
1429func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason {
1430 switch reason {
1431 case genai.FinishReasonStop:
1432 return fantasy.FinishReasonStop
1433 case genai.FinishReasonMaxTokens:
1434 return fantasy.FinishReasonLength
1435 case genai.FinishReasonSafety,
1436 genai.FinishReasonBlocklist,
1437 genai.FinishReasonProhibitedContent,
1438 genai.FinishReasonSPII,
1439 genai.FinishReasonImageSafety:
1440 return fantasy.FinishReasonContentFilter
1441 case genai.FinishReasonRecitation,
1442 genai.FinishReasonLanguage,
1443 genai.FinishReasonMalformedFunctionCall:
1444 return fantasy.FinishReasonError
1445 case genai.FinishReasonOther:
1446 return fantasy.FinishReasonOther
1447 default:
1448 return fantasy.FinishReasonUnknown
1449 }
1450}
1451
1452func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage {
1453 return fantasy.Usage{
1454 InputTokens: int64(usage.PromptTokenCount),
1455 OutputTokens: int64(usage.CandidatesTokenCount),
1456 TotalTokens: int64(usage.TotalTokenCount),
1457 ReasoningTokens: int64(usage.ThoughtsTokenCount),
1458 CacheCreationTokens: 0,
1459 CacheReadTokens: int64(usage.CachedContentTokenCount),
1460 }
1461}