1package google
2
3import (
4 "cmp"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "maps"
11 "net/http"
12 "strings"
13
14 "github.com/charmbracelet/fantasy/ai"
15 "github.com/charmbracelet/x/exp/slice"
16 "github.com/google/uuid"
17 "google.golang.org/genai"
18)
19
// provider is the Google implementation of ai.Provider. It is configured
// once via New and hands out languageModel instances per model ID.
type provider struct {
	options options
}
23
// options holds the configuration collected from Option functions in New.
type options struct {
	apiKey string // Gemini API key used to authenticate requests
	name string // provider name; defaults to "google" in New
	headers map[string]string // extra HTTP headers; NOTE(review): collected but never applied to the genai client — confirm intent
	client *http.Client // optional custom HTTP client; nil lets genai use its default
}
30
// Option mutates the provider options during New.
type Option = func(*options)
32
33func New(opts ...Option) ai.Provider {
34 options := options{
35 headers: map[string]string{},
36 }
37 for _, o := range opts {
38 o(&options)
39 }
40
41 if options.name == "" {
42 options.name = "google"
43 }
44
45 return &provider{
46 options: options,
47 }
48}
49
// WithAPIKey sets the Gemini API key used to authenticate requests.
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}
55
// WithName overrides the provider name reported by languageModel.Provider
// (it becomes the "<name>.generative-ai" prefix). Defaults to "google".
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}
61
// WithHeaders merges the given headers into the provider's header set;
// later calls overwrite keys set by earlier ones.
// NOTE(review): these headers are never forwarded to the genai client in
// LanguageModel — confirm whether this option is meant to have an effect.
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}
67
// WithHTTPClient sets a custom HTTP client for requests to the Gemini API.
func WithHTTPClient(client *http.Client) Option {
	return func(o *options) {
		o.client = client
	}
}
73
// languageModel adapts a single Gemini model to the ai.LanguageModel
// interface.
type languageModel struct {
	provider string // e.g. "google.generative-ai"; reported by Provider()
	modelID string
	client *genai.Client
	providerOptions options
}
80
// LanguageModel implements ai.Provider. It creates a genai client for the
// Gemini API backend and wraps it with the given model ID.
//
// NOTE(review): g.options.headers is not passed into the client config, so
// WithHeaders currently has no effect here — confirm whether headers should
// be forwarded to the genai client.
func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	cc := &genai.ClientConfig{
		APIKey: g.options.apiKey,
		Backend: genai.BackendGeminiAPI,
		HTTPClient: g.options.client,
	}
	// context.Background is used only for client construction; per-request
	// contexts are supplied to Generate/Stream.
	client, err := genai.NewClient(context.Background(), cc)
	if err != nil {
		return nil, err
	}
	return &languageModel{
		modelID: modelID,
		provider: fmt.Sprintf("%s.generative-ai", g.options.name),
		providerOptions: g.options,
		client: client,
	}, nil
}
99
100func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
101 config := &genai.GenerateContentConfig{}
102 providerOptions := &providerOptions{}
103 if v, ok := call.ProviderOptions["google"]; ok {
104 err := ai.ParseOptions(v, providerOptions)
105 if err != nil {
106 return nil, nil, nil, err
107 }
108 }
109
110 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
111
112 if providerOptions.ThinkingConfig != nil &&
113 providerOptions.ThinkingConfig.IncludeThoughts != nil &&
114 *providerOptions.ThinkingConfig.IncludeThoughts &&
115 strings.HasPrefix(a.provider, "google.vertex.") {
116 warnings = append(warnings, ai.CallWarning{
117 Type: ai.CallWarningTypeOther,
118 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
119 "and might not be supported or could behave unexpectedly with the current Google provider " +
120 fmt.Sprintf("(%s)", a.provider),
121 })
122 }
123
124 isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")
125
126 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
127 if len(content) > 0 && content[0].Role == genai.RoleUser {
128 systemParts := []string{}
129 for _, sp := range systemInstructions.Parts {
130 systemParts = append(systemParts, sp.Text)
131 }
132 systemMsg := strings.Join(systemParts, "\n")
133 content[0].Parts = append([]*genai.Part{
134 {
135 Text: systemMsg + "\n\n",
136 },
137 }, content[0].Parts...)
138 systemInstructions = nil
139 }
140 }
141
142 config.SystemInstruction = systemInstructions
143
144 if call.MaxOutputTokens != nil {
145 config.MaxOutputTokens = int32(*call.MaxOutputTokens)
146 }
147
148 if call.Temperature != nil {
149 tmp := float32(*call.Temperature)
150 config.Temperature = &tmp
151 }
152 if call.TopK != nil {
153 tmp := float32(*call.TopK)
154 config.TopK = &tmp
155 }
156 if call.TopP != nil {
157 tmp := float32(*call.TopP)
158 config.TopP = &tmp
159 }
160 if call.FrequencyPenalty != nil {
161 tmp := float32(*call.FrequencyPenalty)
162 config.FrequencyPenalty = &tmp
163 }
164 if call.PresencePenalty != nil {
165 tmp := float32(*call.PresencePenalty)
166 config.PresencePenalty = &tmp
167 }
168
169 if providerOptions.ThinkingConfig != nil {
170 config.ThinkingConfig = &genai.ThinkingConfig{}
171 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
172 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
173 }
174 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
175 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget)
176 config.ThinkingConfig.ThinkingBudget = &tmp
177 }
178 }
179 for _, safetySetting := range providerOptions.SafetySettings {
180 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
181 Category: genai.HarmCategory(safetySetting.Category),
182 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
183 })
184 }
185 if providerOptions.CachedContent != "" {
186 config.CachedContent = providerOptions.CachedContent
187 }
188
189 if len(call.Tools) > 0 {
190 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
191 config.ToolConfig = toolChoice
192 config.Tools = append(config.Tools, &genai.Tool{
193 FunctionDeclarations: tools,
194 })
195 warnings = append(warnings, toolWarnings...)
196 }
197
198 return config, content, warnings, nil
199}
200
201func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) {
202 var systemInstructions *genai.Content
203 var content []*genai.Content
204 var warnings []ai.CallWarning
205
206 finishedSystemBlock := false
207 for _, msg := range prompt {
208 switch msg.Role {
209 case ai.MessageRoleSystem:
210 if finishedSystemBlock {
211 // skip multiple system messages that are separated by user/assistant messages
212 // TODO: see if we need to send error here?
213 continue
214 }
215 finishedSystemBlock = true
216
217 var systemMessages []string
218 for _, part := range msg.Content {
219 text, ok := ai.AsMessagePart[ai.TextPart](part)
220 if !ok || text.Text == "" {
221 continue
222 }
223 systemMessages = append(systemMessages, text.Text)
224 }
225 if len(systemMessages) > 0 {
226 systemInstructions = &genai.Content{
227 Parts: []*genai.Part{
228 {
229 Text: strings.Join(systemMessages, "\n"),
230 },
231 },
232 }
233 }
234 case ai.MessageRoleUser:
235 var parts []*genai.Part
236 for _, part := range msg.Content {
237 switch part.GetType() {
238 case ai.ContentTypeText:
239 text, ok := ai.AsMessagePart[ai.TextPart](part)
240 if !ok || text.Text == "" {
241 continue
242 }
243 parts = append(parts, &genai.Part{
244 Text: text.Text,
245 })
246 case ai.ContentTypeFile:
247 file, ok := ai.AsMessagePart[ai.FilePart](part)
248 if !ok {
249 continue
250 }
251 var encoded []byte
252 base64.StdEncoding.Encode(encoded, file.Data)
253 parts = append(parts, &genai.Part{
254 InlineData: &genai.Blob{
255 Data: encoded,
256 MIMEType: file.MediaType,
257 },
258 })
259 }
260 }
261 if len(parts) > 0 {
262 content = append(content, &genai.Content{
263 Role: genai.RoleUser,
264 Parts: parts,
265 })
266 }
267 case ai.MessageRoleAssistant:
268 var parts []*genai.Part
269 for _, part := range msg.Content {
270 switch part.GetType() {
271 case ai.ContentTypeText:
272 text, ok := ai.AsMessagePart[ai.TextPart](part)
273 if !ok || text.Text == "" {
274 continue
275 }
276 parts = append(parts, &genai.Part{
277 Text: text.Text,
278 })
279 case ai.ContentTypeToolCall:
280 toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
281 if !ok {
282 continue
283 }
284
285 var result map[string]any
286 err := json.Unmarshal([]byte(toolCall.Input), &result)
287 if err != nil {
288 continue
289 }
290 parts = append(parts, &genai.Part{
291 FunctionCall: &genai.FunctionCall{
292 ID: toolCall.ToolCallID,
293 Name: toolCall.ToolName,
294 Args: result,
295 },
296 })
297 }
298 }
299 if len(parts) > 0 {
300 content = append(content, &genai.Content{
301 Role: genai.RoleModel,
302 Parts: parts,
303 })
304 }
305 case ai.MessageRoleTool:
306 var parts []*genai.Part
307 for _, part := range msg.Content {
308 switch part.GetType() {
309 case ai.ContentTypeToolResult:
310 result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
311 if !ok {
312 continue
313 }
314 var toolCall ai.ToolCallPart
315 for _, m := range prompt {
316 if m.Role == ai.MessageRoleAssistant {
317 for _, content := range m.Content {
318 tc, ok := ai.AsMessagePart[ai.ToolCallPart](content)
319 if !ok {
320 continue
321 }
322 if tc.ToolCallID == result.ToolCallID {
323 toolCall = tc
324 break
325 }
326 }
327 }
328 }
329 switch result.Output.GetType() {
330 case ai.ToolResultContentTypeText:
331 content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
332 if !ok {
333 continue
334 }
335 response := map[string]any{"result": content.Text}
336 parts = append(parts, &genai.Part{
337 FunctionResponse: &genai.FunctionResponse{
338 ID: result.ToolCallID,
339 Response: response,
340 Name: toolCall.ToolName,
341 },
342 })
343
344 case ai.ToolResultContentTypeError:
345 content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
346 if !ok {
347 continue
348 }
349 response := map[string]any{"result": content.Error.Error()}
350 parts = append(parts, &genai.Part{
351 FunctionResponse: &genai.FunctionResponse{
352 ID: result.ToolCallID,
353 Response: response,
354 Name: toolCall.ToolName,
355 },
356 })
357
358 }
359 }
360 }
361 if len(parts) > 0 {
362 content = append(content, &genai.Content{
363 Role: genai.RoleUser,
364 Parts: parts,
365 })
366 }
367 }
368 }
369 return systemInstructions, content, warnings
370}
371
372// Generate implements ai.LanguageModel.
373func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
374 config, contents, warnings, err := g.prepareParams(call)
375 if err != nil {
376 return nil, err
377 }
378
379 lastMessage, history, ok := slice.Pop(contents)
380 if !ok {
381 return nil, errors.New("no messages to send")
382 }
383
384 chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
385 if err != nil {
386 return nil, err
387 }
388
389 response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
390 if err != nil {
391 return nil, err
392 }
393
394 return mapResponse(response, warnings)
395}
396
// Model implements ai.LanguageModel. It returns the Gemini model ID this
// instance was created for.
func (g *languageModel) Model() string {
	return g.modelID
}
401
// Provider implements ai.LanguageModel. It returns the provider identifier,
// e.g. "google.generative-ai".
func (g *languageModel) Provider() string {
	return g.provider
}
406
// Stream implements ai.LanguageModel. It streams the model response as a
// sequence of ai.StreamPart events: optional warnings first, then text
// start/delta/end and tool-call events in arrival order, and finally a
// single finish event carrying usage and the finish reason. The yield
// function's boolean return is honored everywhere so consumers can stop
// the stream early.
func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}

	// The genai chat API takes history separately from the message being
	// sent, so split the final message off the converted contents.
	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}

	return func(yield func(ai.StreamPart) bool) {
		// Surface conversion warnings before any content parts.
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		var currentContent string
		var toolCalls []ai.ToolCallContent
		// isActiveText tracks whether a TextStart has been emitted without a
		// matching TextEnd yet; a single text stream with ID "0" is used.
		var isActiveText bool
		var usage ai.Usage

		// Stream the response
		for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
			if err != nil {
				yield(ai.StreamPart{
					Type: ai.StreamPartTypeError,
					Error: err,
				})
				return
			}

			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
				for _, part := range resp.Candidates[0].Content.Parts {
					switch {
					case part.Text != "":
						delta := part.Text
						if delta != "" {
							// Open the text stream lazily on the first delta.
							if !isActiveText {
								isActiveText = true
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeTextStart,
									ID: "0",
								}) {
									return
								}
							}
							if !yield(ai.StreamPart{
								Type: ai.StreamPartTypeTextDelta,
								ID: "0",
								Delta: delta,
							}) {
								return
							}
							currentContent += delta
						}
					case part.FunctionCall != nil:
						// A tool call terminates any open text stream.
						if isActiveText {
							isActiveText = false
							if !yield(ai.StreamPart{
								Type: ai.StreamPartTypeTextEnd,
								ID: "0",
							}) {
								return
							}
						}

						// Prefer the API-provided call ID, then the function
						// name, then a fresh UUID so the ID is never empty.
						toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())

						args, err := json.Marshal(part.FunctionCall.Args)
						if err != nil {
							yield(ai.StreamPart{
								Type: ai.StreamPartTypeError,
								Error: err,
							})
							return
						}

						// Gemini delivers the full call at once, so emit the
						// input start/delta/end trio followed by the complete
						// tool-call event.
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeToolInputStart,
							ID: toolCallID,
							ToolCallName: part.FunctionCall.Name,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeToolInputDelta,
							ID: toolCallID,
							Delta: string(args),
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeToolInputEnd,
							ID: toolCallID,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeToolCall,
							ID: toolCallID,
							ToolCallName: part.FunctionCall.Name,
							ToolCallInput: string(args),
							ProviderExecuted: false,
						}) {
							return
						}

						toolCalls = append(toolCalls, ai.ToolCallContent{
							ToolCallID: toolCallID,
							ToolName: part.FunctionCall.Name,
							Input: string(args),
							ProviderExecuted: false,
						})
					}
				}
			}

			// Usage arrives on (at least) the final chunk; keep the latest.
			if resp.UsageMetadata != nil {
				usage = mapUsage(resp.UsageMetadata)
			}
		}

		// Close a text stream left open by a response ending in text.
		if isActiveText {
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID: "0",
			}) {
				return
			}
		}

		finishReason := ai.FinishReasonStop
		if len(toolCalls) > 0 {
			finishReason = ai.FinishReasonToolCalls
		}

		yield(ai.StreamPart{
			Type: ai.StreamPartTypeFinish,
			Usage: usage,
			FinishReason: finishReason,
		})
	}, nil
}
564
565func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
566 for _, tool := range tools {
567 if tool.GetType() == ai.ToolTypeFunction {
568 ft, ok := tool.(ai.FunctionTool)
569 if !ok {
570 continue
571 }
572
573 required := []string{}
574 var properties map[string]any
575 if props, ok := ft.InputSchema["properties"]; ok {
576 properties, _ = props.(map[string]any)
577 }
578 if req, ok := ft.InputSchema["required"]; ok {
579 if reqArr, ok := req.([]string); ok {
580 required = reqArr
581 }
582 }
583 declaration := &genai.FunctionDeclaration{
584 Name: ft.Name,
585 Description: ft.Description,
586 Parameters: &genai.Schema{
587 Type: genai.TypeObject,
588 Properties: convertSchemaProperties(properties),
589 Required: required,
590 },
591 }
592 googleTools = append(googleTools, declaration)
593 continue
594 }
595 // TODO: handle provider tool calls
596 warnings = append(warnings, ai.CallWarning{
597 Type: ai.CallWarningTypeUnsupportedTool,
598 Tool: tool,
599 Message: "tool is not supported",
600 })
601 }
602 if toolChoice == nil {
603 return
604 }
605 switch *toolChoice {
606 case ai.ToolChoiceAuto:
607 googleToolChoice = &genai.ToolConfig{
608 FunctionCallingConfig: &genai.FunctionCallingConfig{
609 Mode: genai.FunctionCallingConfigModeAuto,
610 },
611 }
612 case ai.ToolChoiceRequired:
613 googleToolChoice = &genai.ToolConfig{
614 FunctionCallingConfig: &genai.FunctionCallingConfig{
615 Mode: genai.FunctionCallingConfigModeAny,
616 },
617 }
618 case ai.ToolChoiceNone:
619 googleToolChoice = &genai.ToolConfig{
620 FunctionCallingConfig: &genai.FunctionCallingConfig{
621 Mode: genai.FunctionCallingConfigModeNone,
622 },
623 }
624 default:
625 googleToolChoice = &genai.ToolConfig{
626 FunctionCallingConfig: &genai.FunctionCallingConfig{
627 Mode: genai.FunctionCallingConfigModeAny,
628 AllowedFunctionNames: []string{
629 string(*toolChoice),
630 },
631 },
632 }
633 }
634 return
635}
636
637func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
638 properties := make(map[string]*genai.Schema)
639
640 for name, param := range parameters {
641 properties[name] = convertToSchema(param)
642 }
643
644 return properties
645}
646
647func convertToSchema(param any) *genai.Schema {
648 schema := &genai.Schema{Type: genai.TypeString}
649
650 paramMap, ok := param.(map[string]any)
651 if !ok {
652 return schema
653 }
654
655 if desc, ok := paramMap["description"].(string); ok {
656 schema.Description = desc
657 }
658
659 typeVal, hasType := paramMap["type"]
660 if !hasType {
661 return schema
662 }
663
664 typeStr, ok := typeVal.(string)
665 if !ok {
666 return schema
667 }
668
669 schema.Type = mapJSONTypeToGoogle(typeStr)
670
671 switch typeStr {
672 case "array":
673 schema.Items = processArrayItems(paramMap)
674 case "object":
675 if props, ok := paramMap["properties"].(map[string]any); ok {
676 schema.Properties = convertSchemaProperties(props)
677 }
678 }
679
680 return schema
681}
682
683func processArrayItems(paramMap map[string]any) *genai.Schema {
684 items, ok := paramMap["items"].(map[string]any)
685 if !ok {
686 return nil
687 }
688
689 return convertToSchema(items)
690}
691
692func mapJSONTypeToGoogle(jsonType string) genai.Type {
693 switch jsonType {
694 case "string":
695 return genai.TypeString
696 case "number":
697 return genai.TypeNumber
698 case "integer":
699 return genai.TypeInteger
700 case "boolean":
701 return genai.TypeBoolean
702 case "array":
703 return genai.TypeArray
704 case "object":
705 return genai.TypeObject
706 default:
707 return genai.TypeString // Default to string for unknown types
708 }
709}
710
711func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
712 if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
713 return nil, errors.New("no response from model")
714 }
715
716 var (
717 content []ai.Content
718 finishReason ai.FinishReason
719 hasToolCalls bool
720 candidate = response.Candidates[0]
721 )
722
723 for _, part := range candidate.Content.Parts {
724 switch {
725 case part.Text != "":
726 content = append(content, ai.TextContent{Text: part.Text})
727 case part.FunctionCall != nil:
728 input, err := json.Marshal(part.FunctionCall.Args)
729 if err != nil {
730 return nil, err
731 }
732 content = append(content, ai.ToolCallContent{
733 ToolCallID: part.FunctionCall.ID,
734 ToolName: part.FunctionCall.Name,
735 Input: string(input),
736 ProviderExecuted: false,
737 })
738 hasToolCalls = true
739 default:
740 return nil, fmt.Errorf("not implemented part type")
741 }
742 }
743
744 if hasToolCalls {
745 finishReason = ai.FinishReasonToolCalls
746 } else {
747 finishReason = mapFinishReason(candidate.FinishReason)
748 }
749
750 return &ai.Response{
751 Content: content,
752 Usage: mapUsage(response.UsageMetadata),
753 FinishReason: finishReason,
754 Warnings: warnings,
755 }, nil
756}
757
758func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
759 switch reason {
760 case genai.FinishReasonStop:
761 return ai.FinishReasonStop
762 case genai.FinishReasonMaxTokens:
763 return ai.FinishReasonLength
764 case genai.FinishReasonSafety,
765 genai.FinishReasonBlocklist,
766 genai.FinishReasonProhibitedContent,
767 genai.FinishReasonSPII,
768 genai.FinishReasonImageSafety:
769 return ai.FinishReasonContentFilter
770 case genai.FinishReasonRecitation,
771 genai.FinishReasonLanguage,
772 genai.FinishReasonMalformedFunctionCall:
773 return ai.FinishReasonError
774 case genai.FinishReasonOther:
775 return ai.FinishReasonOther
776 default:
777 return ai.FinishReasonUnknown
778 }
779}
780
781func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
782 return ai.Usage{
783 InputTokens: int64(usage.ToolUsePromptTokenCount),
784 OutputTokens: int64(usage.CandidatesTokenCount),
785 TotalTokens: int64(usage.TotalTokenCount),
786 ReasoningTokens: int64(usage.ThoughtsTokenCount),
787 CacheCreationTokens: int64(usage.CachedContentTokenCount),
788 CacheReadTokens: 0,
789 }
790}