1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "net/http"
11 "reflect"
12 "strings"
13
14 "charm.land/fantasy"
15 "charm.land/fantasy/object"
16 "charm.land/fantasy/providers/anthropic"
17 "charm.land/fantasy/schema"
18 "cloud.google.com/go/auth"
19 "github.com/charmbracelet/x/exp/slice"
20 "github.com/google/uuid"
21 "google.golang.org/genai"
22)
23
24// Name is the name of the Google provider.
25const Name = "google"
26
27type provider struct {
28 options options
29}
30
31// ToolCallIDFunc defines a function that generates a tool call ID.
32type ToolCallIDFunc = func() string
33
34type options struct {
35 apiKey string
36 name string
37 baseURL string
38 headers map[string]string
39 client *http.Client
40 backend genai.Backend
41 project string
42 location string
43 skipAuth bool
44 toolCallIDFunc ToolCallIDFunc
45 objectMode fantasy.ObjectMode
46}
47
48// Option defines a function that configures Google provider options.
49type Option = func(*options)
50
51// New creates a new Google provider with the given options.
52func New(opts ...Option) (fantasy.Provider, error) {
53 options := options{
54 headers: map[string]string{},
55 toolCallIDFunc: func() string {
56 return uuid.NewString()
57 },
58 }
59 for _, o := range opts {
60 o(&options)
61 }
62
63 options.name = cmp.Or(options.name, Name)
64
65 return &provider{
66 options: options,
67 }, nil
68}
69
70// WithBaseURL sets the base URL for the Google provider.
71func WithBaseURL(baseURL string) Option {
72 return func(o *options) {
73 o.baseURL = baseURL
74 }
75}
76
77// WithGeminiAPIKey sets the Gemini API key for the Google provider.
78func WithGeminiAPIKey(apiKey string) Option {
79 return func(o *options) {
80 o.backend = genai.BackendGeminiAPI
81 o.apiKey = apiKey
82 o.project = ""
83 o.location = ""
84 }
85}
86
87// WithVertex configures the Google provider to use Vertex AI.
88func WithVertex(project, location string) Option {
89 if project == "" || location == "" {
90 panic("project and location must be provided")
91 }
92 return func(o *options) {
93 o.backend = genai.BackendVertexAI
94 o.apiKey = ""
95 o.project = project
96 o.location = location
97 }
98}
99
100// WithSkipAuth configures whether to skip authentication for the Google provider.
101func WithSkipAuth(skipAuth bool) Option {
102 return func(o *options) {
103 o.skipAuth = skipAuth
104 }
105}
106
107// WithName sets the name for the Google provider.
108func WithName(name string) Option {
109 return func(o *options) {
110 o.name = name
111 }
112}
113
114// WithHeaders sets the headers for the Google provider.
115func WithHeaders(headers map[string]string) Option {
116 return func(o *options) {
117 maps.Copy(o.headers, headers)
118 }
119}
120
121// WithHTTPClient sets the HTTP client for the Google provider.
122func WithHTTPClient(client *http.Client) Option {
123 return func(o *options) {
124 o.client = client
125 }
126}
127
128// WithToolCallIDFunc sets the function that generates a tool call ID.
129func WithToolCallIDFunc(f ToolCallIDFunc) Option {
130 return func(o *options) {
131 o.toolCallIDFunc = f
132 }
133}
134
135// WithObjectMode sets the object generation mode for the Google provider.
136func WithObjectMode(om fantasy.ObjectMode) Option {
137 return func(o *options) {
138 o.objectMode = om
139 }
140}
141
142func (*provider) Name() string {
143 return Name
144}
145
146type languageModel struct {
147 provider string
148 modelID string
149 client *genai.Client
150 providerOptions options
151 objectMode fantasy.ObjectMode
152}
153
154// LanguageModel implements fantasy.Provider.
155func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.LanguageModel, error) {
156 if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
157 p, err := anthropic.New(
158 anthropic.WithVertex(a.options.project, a.options.location),
159 anthropic.WithHTTPClient(a.options.client),
160 anthropic.WithSkipAuth(a.options.skipAuth),
161 )
162 if err != nil {
163 return nil, err
164 }
165 return p.LanguageModel(ctx, modelID)
166 }
167
168 cc := &genai.ClientConfig{
169 HTTPClient: a.options.client,
170 Backend: a.options.backend,
171 APIKey: a.options.apiKey,
172 Project: a.options.project,
173 Location: a.options.location,
174 }
175 if a.options.skipAuth {
176 cc.Credentials = &auth.Credentials{TokenProvider: dummyTokenProvider{}}
177 } else if cc.Backend == genai.BackendVertexAI {
178 if err := cc.UseDefaultCredentials(); err != nil {
179 return nil, err
180 }
181 }
182
183 if a.options.baseURL != "" || len(a.options.headers) > 0 {
184 headers := http.Header{}
185 for k, v := range a.options.headers {
186 headers.Add(k, v)
187 }
188 cc.HTTPOptions = genai.HTTPOptions{
189 BaseURL: a.options.baseURL,
190 Headers: headers,
191 }
192 }
193 client, err := genai.NewClient(ctx, cc)
194 if err != nil {
195 return nil, err
196 }
197
198 objectMode := a.options.objectMode
199 if objectMode == "" {
200 objectMode = fantasy.ObjectModeAuto
201 }
202
203 return &languageModel{
204 modelID: modelID,
205 provider: a.options.name,
206 providerOptions: a.options,
207 client: client,
208 objectMode: objectMode,
209 }, nil
210}
211
212func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentConfig, []*genai.Content, []fantasy.CallWarning, error) {
213 config := &genai.GenerateContentConfig{}
214
215 providerOptions := &ProviderOptions{}
216 if v, ok := call.ProviderOptions[Name]; ok {
217 providerOptions, ok = v.(*ProviderOptions)
218 if !ok {
219 return nil, nil, nil, &fantasy.Error{Title: "invalid argument", Message: "google provider options should be *google.ProviderOptions"}
220 }
221 }
222
223 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
224
225 if providerOptions.ThinkingConfig != nil {
226 if providerOptions.ThinkingConfig.IncludeThoughts != nil &&
227 *providerOptions.ThinkingConfig.IncludeThoughts &&
228 strings.HasPrefix(g.provider, "google.vertex.") {
229 warnings = append(warnings, fantasy.CallWarning{
230 Type: fantasy.CallWarningTypeOther,
231 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
232 "and might not be supported or could behave unexpectedly with the current Google provider " +
233 fmt.Sprintf("(%s)", g.provider),
234 })
235 }
236
237 if providerOptions.ThinkingConfig.ThinkingBudget != nil &&
238 *providerOptions.ThinkingConfig.ThinkingBudget < 128 {
239 warnings = append(warnings, fantasy.CallWarning{
240 Type: fantasy.CallWarningTypeOther,
241 Message: "The 'thinking_budget' option can not be under 128 and will be set to 128 by default",
242 })
243 providerOptions.ThinkingConfig.ThinkingBudget = fantasy.Opt(int64(128))
244 }
245 }
246
247 isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")
248
249 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
250 if len(content) > 0 && content[0].Role == genai.RoleUser {
251 systemParts := []string{}
252 for _, sp := range systemInstructions.Parts {
253 systemParts = append(systemParts, sp.Text)
254 }
255 systemMsg := strings.Join(systemParts, "\n")
256 content[0].Parts = append([]*genai.Part{
257 {
258 Text: systemMsg + "\n\n",
259 },
260 }, content[0].Parts...)
261 systemInstructions = nil
262 }
263 }
264
265 config.SystemInstruction = systemInstructions
266
267 if call.MaxOutputTokens != nil {
268 config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
269 }
270
271 if call.Temperature != nil {
272 tmp := float32(*call.Temperature)
273 config.Temperature = &tmp
274 }
275 if call.TopK != nil {
276 tmp := float32(*call.TopK)
277 config.TopK = &tmp
278 }
279 if call.TopP != nil {
280 tmp := float32(*call.TopP)
281 config.TopP = &tmp
282 }
283 if call.FrequencyPenalty != nil {
284 tmp := float32(*call.FrequencyPenalty)
285 config.FrequencyPenalty = &tmp
286 }
287 if call.PresencePenalty != nil {
288 tmp := float32(*call.PresencePenalty)
289 config.PresencePenalty = &tmp
290 }
291
292 if providerOptions.ThinkingConfig != nil {
293 config.ThinkingConfig = &genai.ThinkingConfig{}
294 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
295 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
296 }
297 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
298 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
299 config.ThinkingConfig.ThinkingBudget = &tmp
300 }
301 }
302 for _, safetySetting := range providerOptions.SafetySettings {
303 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
304 Category: genai.HarmCategory(safetySetting.Category),
305 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
306 })
307 }
308 if providerOptions.CachedContent != "" {
309 config.CachedContent = providerOptions.CachedContent
310 }
311
312 if len(call.Tools) > 0 {
313 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
314 config.ToolConfig = toolChoice
315 config.Tools = append(config.Tools, &genai.Tool{
316 FunctionDeclarations: tools,
317 })
318 warnings = append(warnings, toolWarnings...)
319 }
320
321 return config, content, warnings, nil
322}
323
324func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, []fantasy.CallWarning) { //nolint: unparam
325 var systemInstructions *genai.Content
326 var content []*genai.Content
327 var warnings []fantasy.CallWarning
328
329 finishedSystemBlock := false
330 for _, msg := range prompt {
331 switch msg.Role {
332 case fantasy.MessageRoleSystem:
333 if finishedSystemBlock {
334 // skip multiple system messages that are separated by user/assistant messages
335 // TODO: see if we need to send error here?
336 continue
337 }
338 finishedSystemBlock = true
339
340 var systemMessages []string
341 for _, part := range msg.Content {
342 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
343 if !ok || text.Text == "" {
344 continue
345 }
346 systemMessages = append(systemMessages, text.Text)
347 }
348 if len(systemMessages) > 0 {
349 systemInstructions = &genai.Content{
350 Parts: []*genai.Part{
351 {
352 Text: strings.Join(systemMessages, "\n"),
353 },
354 },
355 }
356 }
357 case fantasy.MessageRoleUser:
358 var parts []*genai.Part
359 for _, part := range msg.Content {
360 switch part.GetType() {
361 case fantasy.ContentTypeText:
362 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
363 if !ok || text.Text == "" {
364 continue
365 }
366 parts = append(parts, &genai.Part{
367 Text: text.Text,
368 })
369 case fantasy.ContentTypeFile:
370 file, ok := fantasy.AsMessagePart[fantasy.FilePart](part)
371 if !ok {
372 continue
373 }
374 parts = append(parts, &genai.Part{
375 InlineData: &genai.Blob{
376 Data: file.Data,
377 MIMEType: file.MediaType,
378 },
379 })
380 }
381 }
382 if len(parts) > 0 {
383 content = append(content, &genai.Content{
384 Role: genai.RoleUser,
385 Parts: parts,
386 })
387 }
388 case fantasy.MessageRoleAssistant:
389 var parts []*genai.Part
390 // INFO: (kujtim) this is kind of a hacky way to include thinking for google
391 // weirdly thinking needs to be included in a function call
392 var signature []byte
393 for _, part := range msg.Content {
394 switch part.GetType() {
395 case fantasy.ContentTypeText:
396 text, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
397 if !ok || text.Text == "" {
398 continue
399 }
400 parts = append(parts, &genai.Part{
401 Text: text.Text,
402 })
403 case fantasy.ContentTypeToolCall:
404 toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
405 if !ok {
406 continue
407 }
408
409 var result map[string]any
410 err := json.Unmarshal([]byte(toolCall.Input), &result)
411 if err != nil {
412 continue
413 }
414 parts = append(parts, &genai.Part{
415 FunctionCall: &genai.FunctionCall{
416 ID: toolCall.ToolCallID,
417 Name: toolCall.ToolName,
418 Args: result,
419 },
420 ThoughtSignature: signature,
421 })
422 // reset
423 signature = nil
424 case fantasy.ContentTypeReasoning:
425 reasoning, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](part)
426 if !ok {
427 continue
428 }
429 metadata, ok := reasoning.ProviderOptions[Name]
430 if !ok {
431 continue
432 }
433 reasoningMetadata, ok := metadata.(*ReasoningMetadata)
434 if !ok {
435 continue
436 }
437 if !ok || reasoningMetadata.Signature == "" {
438 continue
439 }
440 signature = []byte(reasoningMetadata.Signature)
441 }
442 }
443 if len(parts) > 0 {
444 content = append(content, &genai.Content{
445 Role: genai.RoleModel,
446 Parts: parts,
447 })
448 }
449 case fantasy.MessageRoleTool:
450 var parts []*genai.Part
451 for _, part := range msg.Content {
452 switch part.GetType() {
453 case fantasy.ContentTypeToolResult:
454 result, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
455 if !ok {
456 continue
457 }
458 var toolCall fantasy.ToolCallPart
459 for _, m := range prompt {
460 if m.Role == fantasy.MessageRoleAssistant {
461 for _, content := range m.Content {
462 tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](content)
463 if !ok {
464 continue
465 }
466 if tc.ToolCallID == result.ToolCallID {
467 toolCall = tc
468 break
469 }
470 }
471 }
472 }
473 switch result.Output.GetType() {
474 case fantasy.ToolResultContentTypeText:
475 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Output)
476 if !ok {
477 continue
478 }
479 response := map[string]any{"result": content.Text}
480 parts = append(parts, &genai.Part{
481 FunctionResponse: &genai.FunctionResponse{
482 ID: result.ToolCallID,
483 Response: response,
484 Name: toolCall.ToolName,
485 },
486 })
487
488 case fantasy.ToolResultContentTypeError:
489 content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
490 if !ok {
491 continue
492 }
493 response := map[string]any{"result": content.Error.Error()}
494 parts = append(parts, &genai.Part{
495 FunctionResponse: &genai.FunctionResponse{
496 ID: result.ToolCallID,
497 Response: response,
498 Name: toolCall.ToolName,
499 },
500 })
501 }
502 }
503 }
504 if len(parts) > 0 {
505 content = append(content, &genai.Content{
506 Role: genai.RoleUser,
507 Parts: parts,
508 })
509 }
510 default:
511 panic("unsupported message role: " + msg.Role)
512 }
513 }
514 return systemInstructions, content, warnings
515}
516
517// Generate implements fantasy.LanguageModel.
518func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
519 config, contents, warnings, err := g.prepareParams(call)
520 if err != nil {
521 return nil, err
522 }
523
524 lastMessage, history, ok := slice.Pop(contents)
525 if !ok {
526 return nil, errors.New("no messages to send")
527 }
528
529 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
530 if err != nil {
531 return nil, err
532 }
533
534 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
535 if err != nil {
536 return nil, toProviderErr(err)
537 }
538
539 return g.mapResponse(response, warnings)
540}
541
542// Model implements fantasy.LanguageModel.
543func (g *languageModel) Model() string {
544 return g.modelID
545}
546
547// Provider implements fantasy.LanguageModel.
548func (g *languageModel) Provider() string {
549 return g.provider
550}
551
552// Stream implements fantasy.LanguageModel.
553func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
554 config, contents, warnings, err := g.prepareParams(call)
555 if err != nil {
556 return nil, err
557 }
558
559 lastMessage, history, ok := slice.Pop(contents)
560 if !ok {
561 return nil, errors.New("no messages to send")
562 }
563
564 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
565 if err != nil {
566 return nil, err
567 }
568
569 return func(yield func(fantasy.StreamPart) bool) {
570 if len(warnings) > 0 {
571 if !yield(fantasy.StreamPart{
572 Type: fantasy.StreamPartTypeWarnings,
573 Warnings: warnings,
574 }) {
575 return
576 }
577 }
578
579 var currentContent string
580 var toolCalls []fantasy.ToolCallContent
581 var isActiveText bool
582 var isActiveReasoning bool
583 var blockCounter int
584 var currentTextBlockID string
585 var currentReasoningBlockID string
586 var usage *fantasy.Usage
587 var lastFinishReason fantasy.FinishReason
588
589 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
590 if err != nil {
591 yield(fantasy.StreamPart{
592 Type: fantasy.StreamPartTypeError,
593 Error: toProviderErr(err),
594 })
595 return
596 }
597
598 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
599 for _, part := range resp.Candidates[0].Content.Parts {
600 switch {
601 case part.Text != "":
602 delta := part.Text
603 if delta != "" {
604 // Check if this is a reasoning/thought part
605 if part.Thought {
606 // End any active text block before starting reasoning
607 if isActiveText {
608 isActiveText = false
609 if !yield(fantasy.StreamPart{
610 Type: fantasy.StreamPartTypeTextEnd,
611 ID: currentTextBlockID,
612 }) {
613 return
614 }
615 }
616
617 // Start new reasoning block if not already active
618 if !isActiveReasoning {
619 isActiveReasoning = true
620 currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
621 blockCounter++
622 if !yield(fantasy.StreamPart{
623 Type: fantasy.StreamPartTypeReasoningStart,
624 ID: currentReasoningBlockID,
625 }) {
626 return
627 }
628 }
629
630 if !yield(fantasy.StreamPart{
631 Type: fantasy.StreamPartTypeReasoningDelta,
632 ID: currentReasoningBlockID,
633 Delta: delta,
634 }) {
635 return
636 }
637 } else {
638 // Regular text part
639 // End any active reasoning block before starting text
640 if isActiveReasoning {
641 isActiveReasoning = false
642 metadata := &ReasoningMetadata{
643 Signature: string(part.ThoughtSignature),
644 }
645 if !yield(fantasy.StreamPart{
646 Type: fantasy.StreamPartTypeReasoningEnd,
647 ID: currentReasoningBlockID,
648 ProviderMetadata: fantasy.ProviderMetadata{
649 Name: metadata,
650 },
651 }) {
652 return
653 }
654 }
655
656 // Start new text block if not already active
657 if !isActiveText {
658 isActiveText = true
659 currentTextBlockID = fmt.Sprintf("%d", blockCounter)
660 blockCounter++
661 if !yield(fantasy.StreamPart{
662 Type: fantasy.StreamPartTypeTextStart,
663 ID: currentTextBlockID,
664 }) {
665 return
666 }
667 }
668
669 if !yield(fantasy.StreamPart{
670 Type: fantasy.StreamPartTypeTextDelta,
671 ID: currentTextBlockID,
672 Delta: delta,
673 }) {
674 return
675 }
676 currentContent += delta
677 }
678 }
679 case part.FunctionCall != nil:
680 // End any active text or reasoning blocks
681 if isActiveText {
682 isActiveText = false
683 if !yield(fantasy.StreamPart{
684 Type: fantasy.StreamPartTypeTextEnd,
685 ID: currentTextBlockID,
686 }) {
687 return
688 }
689 }
690 if isActiveReasoning {
691 isActiveReasoning = false
692
693 metadata := &ReasoningMetadata{
694 Signature: string(part.ThoughtSignature),
695 }
696 if !yield(fantasy.StreamPart{
697 Type: fantasy.StreamPartTypeReasoningEnd,
698 ID: currentReasoningBlockID,
699 ProviderMetadata: fantasy.ProviderMetadata{
700 Name: metadata,
701 },
702 }) {
703 return
704 }
705 }
706
707 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
708
709 args, err := json.Marshal(part.FunctionCall.Args)
710 if err != nil {
711 yield(fantasy.StreamPart{
712 Type: fantasy.StreamPartTypeError,
713 Error: err,
714 })
715 return
716 }
717
718 if !yield(fantasy.StreamPart{
719 Type: fantasy.StreamPartTypeToolInputStart,
720 ID: toolCallID,
721 ToolCallName: part.FunctionCall.Name,
722 }) {
723 return
724 }
725
726 if !yield(fantasy.StreamPart{
727 Type: fantasy.StreamPartTypeToolInputDelta,
728 ID: toolCallID,
729 Delta: string(args),
730 }) {
731 return
732 }
733
734 if !yield(fantasy.StreamPart{
735 Type: fantasy.StreamPartTypeToolInputEnd,
736 ID: toolCallID,
737 }) {
738 return
739 }
740
741 if !yield(fantasy.StreamPart{
742 Type: fantasy.StreamPartTypeToolCall,
743 ID: toolCallID,
744 ToolCallName: part.FunctionCall.Name,
745 ToolCallInput: string(args),
746 ProviderExecuted: false,
747 }) {
748 return
749 }
750
751 toolCalls = append(toolCalls, fantasy.ToolCallContent{
752 ToolCallID: toolCallID,
753 ToolName: part.FunctionCall.Name,
754 Input: string(args),
755 ProviderExecuted: false,
756 })
757 }
758 }
759 }
760
761 // we need to make sure that there is actual tokendata
762 if resp.UsageMetadata != nil && resp.UsageMetadata.TotalTokenCount != 0 {
763 currentUsage := mapUsage(resp.UsageMetadata)
764 // if first usage chunk
765 if usage == nil {
766 usage = ¤tUsage
767 } else {
768 usage.OutputTokens += currentUsage.OutputTokens
769 usage.ReasoningTokens += currentUsage.ReasoningTokens
770 usage.CacheReadTokens += currentUsage.CacheReadTokens
771 }
772 }
773
774 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
775 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
776 }
777 }
778
779 // Close any open blocks before finishing
780 if isActiveText {
781 if !yield(fantasy.StreamPart{
782 Type: fantasy.StreamPartTypeTextEnd,
783 ID: currentTextBlockID,
784 }) {
785 return
786 }
787 }
788 if isActiveReasoning {
789 if !yield(fantasy.StreamPart{
790 Type: fantasy.StreamPartTypeReasoningEnd,
791 ID: currentReasoningBlockID,
792 }) {
793 return
794 }
795 }
796
797 finishReason := lastFinishReason
798 if len(toolCalls) > 0 {
799 finishReason = fantasy.FinishReasonToolCalls
800 } else if finishReason == "" {
801 finishReason = fantasy.FinishReasonStop
802 }
803
804 yield(fantasy.StreamPart{
805 Type: fantasy.StreamPartTypeFinish,
806 Usage: *usage,
807 FinishReason: finishReason,
808 })
809 }, nil
810}
811
812// GenerateObject implements fantasy.LanguageModel.
813func (g *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
814 switch g.objectMode {
815 case fantasy.ObjectModeText:
816 return object.GenerateWithText(ctx, g, call)
817 case fantasy.ObjectModeTool:
818 return object.GenerateWithTool(ctx, g, call)
819 default:
820 return g.generateObjectWithJSONMode(ctx, call)
821 }
822}
823
824// StreamObject implements fantasy.LanguageModel.
825func (g *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
826 switch g.objectMode {
827 case fantasy.ObjectModeTool:
828 return object.StreamWithTool(ctx, g, call)
829 case fantasy.ObjectModeText:
830 return object.StreamWithText(ctx, g, call)
831 default:
832 return g.streamObjectWithJSONMode(ctx, call)
833 }
834}
835
836func (g *languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
837 // Convert our Schema to Google's JSON Schema format
838 jsonSchemaMap := schema.ToMap(call.Schema)
839
840 // Build request using prepareParams
841 fantasyCall := fantasy.Call{
842 Prompt: call.Prompt,
843 MaxOutputTokens: call.MaxOutputTokens,
844 Temperature: call.Temperature,
845 TopP: call.TopP,
846 TopK: call.TopK,
847 PresencePenalty: call.PresencePenalty,
848 FrequencyPenalty: call.FrequencyPenalty,
849 ProviderOptions: call.ProviderOptions,
850 }
851
852 config, contents, warnings, err := g.prepareParams(fantasyCall)
853 if err != nil {
854 return nil, err
855 }
856
857 // Set ResponseMIMEType and ResponseJsonSchema for structured output
858 config.ResponseMIMEType = "application/json"
859 config.ResponseJsonSchema = jsonSchemaMap
860
861 lastMessage, history, ok := slice.Pop(contents)
862 if !ok {
863 return nil, errors.New("no messages to send")
864 }
865
866 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
867 if err != nil {
868 return nil, err
869 }
870
871 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
872 if err != nil {
873 return nil, toProviderErr(err)
874 }
875
876 mappedResponse, err := g.mapResponse(response, warnings)
877 if err != nil {
878 return nil, err
879 }
880
881 jsonText := mappedResponse.Content.Text()
882 if jsonText == "" {
883 return nil, &fantasy.NoObjectGeneratedError{
884 RawText: "",
885 ParseError: fmt.Errorf("no text content in response"),
886 Usage: mappedResponse.Usage,
887 FinishReason: mappedResponse.FinishReason,
888 }
889 }
890
891 // Parse and validate
892 var obj any
893 if call.RepairText != nil {
894 obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
895 } else {
896 obj, err = schema.ParseAndValidate(jsonText, call.Schema)
897 }
898
899 if err != nil {
900 // Add usage info to error
901 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
902 nogErr.Usage = mappedResponse.Usage
903 nogErr.FinishReason = mappedResponse.FinishReason
904 }
905 return nil, err
906 }
907
908 return &fantasy.ObjectResponse{
909 Object: obj,
910 RawText: jsonText,
911 Usage: mappedResponse.Usage,
912 FinishReason: mappedResponse.FinishReason,
913 Warnings: warnings,
914 ProviderMetadata: mappedResponse.ProviderMetadata,
915 }, nil
916}
917
918func (g *languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
919 // Convert our Schema to Google's JSON Schema format
920 jsonSchemaMap := schema.ToMap(call.Schema)
921
922 // Build request using prepareParams
923 fantasyCall := fantasy.Call{
924 Prompt: call.Prompt,
925 MaxOutputTokens: call.MaxOutputTokens,
926 Temperature: call.Temperature,
927 TopP: call.TopP,
928 TopK: call.TopK,
929 PresencePenalty: call.PresencePenalty,
930 FrequencyPenalty: call.FrequencyPenalty,
931 ProviderOptions: call.ProviderOptions,
932 }
933
934 config, contents, warnings, err := g.prepareParams(fantasyCall)
935 if err != nil {
936 return nil, err
937 }
938
939 // Set ResponseMIMEType and ResponseJsonSchema for structured output
940 config.ResponseMIMEType = "application/json"
941 config.ResponseJsonSchema = jsonSchemaMap
942
943 lastMessage, history, ok := slice.Pop(contents)
944 if !ok {
945 return nil, errors.New("no messages to send")
946 }
947
948 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
949 if err != nil {
950 return nil, err
951 }
952
953 return func(yield func(fantasy.ObjectStreamPart) bool) {
954 if len(warnings) > 0 {
955 if !yield(fantasy.ObjectStreamPart{
956 Type: fantasy.ObjectStreamPartTypeObject,
957 Warnings: warnings,
958 }) {
959 return
960 }
961 }
962
963 var accumulated string
964 var lastParsedObject any
965 var usage *fantasy.Usage
966 var lastFinishReason fantasy.FinishReason
967 var streamErr error
968
969 for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
970 if err != nil {
971 streamErr = toProviderErr(err)
972 yield(fantasy.ObjectStreamPart{
973 Type: fantasy.ObjectStreamPartTypeError,
974 Error: streamErr,
975 })
976 return
977 }
978
979 if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
980 for _, part := range resp.Candidates[0].Content.Parts {
981 if part.Text != "" && !part.Thought {
982 accumulated += part.Text
983
984 // Try to parse the accumulated text
985 obj, state, parseErr := schema.ParsePartialJSON(accumulated)
986
987 // If we successfully parsed, validate and emit
988 if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
989 if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
990 // Only emit if object is different from last
991 if !reflect.DeepEqual(obj, lastParsedObject) {
992 if !yield(fantasy.ObjectStreamPart{
993 Type: fantasy.ObjectStreamPartTypeObject,
994 Object: obj,
995 }) {
996 return
997 }
998 lastParsedObject = obj
999 }
1000 }
1001 }
1002
1003 // If parsing failed and we have a repair function, try it
1004 if state == schema.ParseStateFailed && call.RepairText != nil {
1005 repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
1006 if repairErr == nil {
1007 obj2, state2, _ := schema.ParsePartialJSON(repairedText)
1008 if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
1009 schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
1010 if !reflect.DeepEqual(obj2, lastParsedObject) {
1011 if !yield(fantasy.ObjectStreamPart{
1012 Type: fantasy.ObjectStreamPartTypeObject,
1013 Object: obj2,
1014 }) {
1015 return
1016 }
1017 lastParsedObject = obj2
1018 }
1019 }
1020 }
1021 }
1022 }
1023 }
1024 }
1025
1026 // we need to make sure that there is actual tokendata
1027 if resp.UsageMetadata != nil && resp.UsageMetadata.TotalTokenCount != 0 {
1028 currentUsage := mapUsage(resp.UsageMetadata)
1029 if usage == nil {
1030 usage = ¤tUsage
1031 } else {
1032 usage.OutputTokens += currentUsage.OutputTokens
1033 usage.ReasoningTokens += currentUsage.ReasoningTokens
1034 usage.CacheReadTokens += currentUsage.CacheReadTokens
1035 }
1036 }
1037
1038 if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
1039 lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
1040 }
1041 }
1042
1043 // Final validation and emit
1044 if streamErr == nil && lastParsedObject != nil {
1045 finishReason := lastFinishReason
1046 if finishReason == "" {
1047 finishReason = fantasy.FinishReasonStop
1048 }
1049
1050 yield(fantasy.ObjectStreamPart{
1051 Type: fantasy.ObjectStreamPartTypeFinish,
1052 Usage: *usage,
1053 FinishReason: finishReason,
1054 })
1055 } else if streamErr == nil && lastParsedObject == nil {
1056 // No object was generated
1057 finalUsage := fantasy.Usage{}
1058 if usage != nil {
1059 finalUsage = *usage
1060 }
1061 yield(fantasy.ObjectStreamPart{
1062 Type: fantasy.ObjectStreamPartTypeError,
1063 Error: &fantasy.NoObjectGeneratedError{
1064 RawText: accumulated,
1065 ParseError: fmt.Errorf("no valid object generated in stream"),
1066 Usage: finalUsage,
1067 FinishReason: lastFinishReason,
1068 },
1069 })
1070 }
1071 }, nil
1072}
1073
1074func toGoogleTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []fantasy.CallWarning) {
1075 for _, tool := range tools {
1076 if tool.GetType() == fantasy.ToolTypeFunction {
1077 ft, ok := tool.(fantasy.FunctionTool)
1078 if !ok {
1079 continue
1080 }
1081
1082 required := []string{}
1083 var properties map[string]any
1084 if props, ok := ft.InputSchema["properties"]; ok {
1085 properties, _ = props.(map[string]any)
1086 }
1087 if req, ok := ft.InputSchema["required"]; ok {
1088 if reqArr, ok := req.([]string); ok {
1089 required = reqArr
1090 }
1091 }
1092 declaration := &genai.FunctionDeclaration{
1093 Name: ft.Name,
1094 Description: ft.Description,
1095 Parameters: &genai.Schema{
1096 Type: genai.TypeObject,
1097 Properties: convertSchemaProperties(properties),
1098 Required: required,
1099 },
1100 }
1101 googleTools = append(googleTools, declaration)
1102 continue
1103 }
1104 // TODO: handle provider tool calls
1105 warnings = append(warnings, fantasy.CallWarning{
1106 Type: fantasy.CallWarningTypeUnsupportedTool,
1107 Tool: tool,
1108 Message: "tool is not supported",
1109 })
1110 }
1111 if toolChoice == nil {
1112 return googleTools, googleToolChoice, warnings
1113 }
1114 switch *toolChoice {
1115 case fantasy.ToolChoiceAuto:
1116 googleToolChoice = &genai.ToolConfig{
1117 FunctionCallingConfig: &genai.FunctionCallingConfig{
1118 Mode: genai.FunctionCallingConfigModeAuto,
1119 },
1120 }
1121 case fantasy.ToolChoiceRequired:
1122 googleToolChoice = &genai.ToolConfig{
1123 FunctionCallingConfig: &genai.FunctionCallingConfig{
1124 Mode: genai.FunctionCallingConfigModeAny,
1125 },
1126 }
1127 case fantasy.ToolChoiceNone:
1128 googleToolChoice = &genai.ToolConfig{
1129 FunctionCallingConfig: &genai.FunctionCallingConfig{
1130 Mode: genai.FunctionCallingConfigModeNone,
1131 },
1132 }
1133 default:
1134 googleToolChoice = &genai.ToolConfig{
1135 FunctionCallingConfig: &genai.FunctionCallingConfig{
1136 Mode: genai.FunctionCallingConfigModeAny,
1137 AllowedFunctionNames: []string{
1138 string(*toolChoice),
1139 },
1140 },
1141 }
1142 }
1143 return googleTools, googleToolChoice, warnings
1144}
1145
1146func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
1147 properties := make(map[string]*genai.Schema)
1148
1149 for name, param := range parameters {
1150 properties[name] = convertToSchema(param)
1151 }
1152
1153 return properties
1154}
1155
1156func convertToSchema(param any) *genai.Schema {
1157 schema := &genai.Schema{Type: genai.TypeString}
1158
1159 paramMap, ok := param.(map[string]any)
1160 if !ok {
1161 return schema
1162 }
1163
1164 if desc, ok := paramMap["description"].(string); ok {
1165 schema.Description = desc
1166 }
1167
1168 typeVal, hasType := paramMap["type"]
1169 if !hasType {
1170 return schema
1171 }
1172
1173 typeStr, ok := typeVal.(string)
1174 if !ok {
1175 return schema
1176 }
1177
1178 schema.Type = mapJSONTypeToGoogle(typeStr)
1179
1180 switch typeStr {
1181 case "array":
1182 schema.Items = processArrayItems(paramMap)
1183 case "object":
1184 if props, ok := paramMap["properties"].(map[string]any); ok {
1185 schema.Properties = convertSchemaProperties(props)
1186 }
1187 }
1188
1189 return schema
1190}
1191
1192func processArrayItems(paramMap map[string]any) *genai.Schema {
1193 items, ok := paramMap["items"].(map[string]any)
1194 if !ok {
1195 return nil
1196 }
1197
1198 return convertToSchema(items)
1199}
1200
1201func mapJSONTypeToGoogle(jsonType string) genai.Type {
1202 switch jsonType {
1203 case "string":
1204 return genai.TypeString
1205 case "number":
1206 return genai.TypeNumber
1207 case "integer":
1208 return genai.TypeInteger
1209 case "boolean":
1210 return genai.TypeBoolean
1211 case "array":
1212 return genai.TypeArray
1213 case "object":
1214 return genai.TypeObject
1215 default:
1216 return genai.TypeString // Default to string for unknown types
1217 }
1218}
1219
1220func (g languageModel) mapResponse(response *genai.GenerateContentResponse, warnings []fantasy.CallWarning) (*fantasy.Response, error) {
1221 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
1222 return nil, errors.New("no response from model")
1223 }
1224
1225 var (
1226 content []fantasy.Content
1227 finishReason fantasy.FinishReason
1228 hasToolCalls bool
1229 candidate = response.Candidates[0]
1230 )
1231
1232 for _, part := range candidate.Content.Parts {
1233 switch {
1234 case part.Text != "":
1235 if part.Thought {
1236 metadata := &ReasoningMetadata{
1237 Signature: string(part.ThoughtSignature),
1238 }
1239 content = append(content, fantasy.ReasoningContent{Text: part.Text, ProviderMetadata: fantasy.ProviderMetadata{Name: metadata}})
1240 } else {
1241 content = append(content, fantasy.TextContent{Text: part.Text})
1242 }
1243 case part.FunctionCall != nil:
1244 input, err := json.Marshal(part.FunctionCall.Args)
1245 if err != nil {
1246 return nil, err
1247 }
1248 toolCallID := cmp.Or(part.FunctionCall.ID, g.providerOptions.toolCallIDFunc())
1249 content = append(content, fantasy.ToolCallContent{
1250 ToolCallID: toolCallID,
1251 ToolName: part.FunctionCall.Name,
1252 Input: string(input),
1253 ProviderExecuted: false,
1254 })
1255 hasToolCalls = true
1256 default:
1257 // Silently skip unknown part types instead of erroring
1258 // This allows for forward compatibility with new part types
1259 }
1260 }
1261
1262 if hasToolCalls {
1263 finishReason = fantasy.FinishReasonToolCalls
1264 } else {
1265 finishReason = mapFinishReason(candidate.FinishReason)
1266 }
1267
1268 return &fantasy.Response{
1269 Content: content,
1270 Usage: mapUsage(response.UsageMetadata),
1271 FinishReason: finishReason,
1272 Warnings: warnings,
1273 }, nil
1274}
1275
1276func mapFinishReason(reason genai.FinishReason) fantasy.FinishReason {
1277 switch reason {
1278 case genai.FinishReasonStop:
1279 return fantasy.FinishReasonStop
1280 case genai.FinishReasonMaxTokens:
1281 return fantasy.FinishReasonLength
1282 case genai.FinishReasonSafety,
1283 genai.FinishReasonBlocklist,
1284 genai.FinishReasonProhibitedContent,
1285 genai.FinishReasonSPII,
1286 genai.FinishReasonImageSafety:
1287 return fantasy.FinishReasonContentFilter
1288 case genai.FinishReasonRecitation,
1289 genai.FinishReasonLanguage,
1290 genai.FinishReasonMalformedFunctionCall:
1291 return fantasy.FinishReasonError
1292 case genai.FinishReasonOther:
1293 return fantasy.FinishReasonOther
1294 default:
1295 return fantasy.FinishReasonUnknown
1296 }
1297}
1298
1299func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) fantasy.Usage {
1300 return fantasy.Usage{
1301 InputTokens: int64(usage.PromptTokenCount),
1302 OutputTokens: int64(usage.CandidatesTokenCount),
1303 TotalTokens: int64(usage.TotalTokenCount),
1304 ReasoningTokens: int64(usage.ThoughtsTokenCount),
1305 CacheCreationTokens: 0,
1306 CacheReadTokens: int64(usage.CachedContentTokenCount),
1307 }
1308}