package google

import (
	"cmp"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"maps"
	"net/http"
	"strings"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/x/exp/slice"
	"github.com/google/uuid"
	"google.golang.org/genai"
)
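
// provider implements ai.Provider on top of google.golang.org/genai.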
type provider struct {
	options options
}

type options struct {
	apiKey  string
	name    string
	headers map[string]string
	client  *http.Client
}
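
// Option configures the provider returned by New.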
type Option = func(*options)
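
// New builds a Google Gemini provider from the given options. For example
// (the model ID below is illustrative):
//
//	p := New(WithAPIKey(key))
//	model, err := p.LanguageModel("gemini-2.0-flash")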
func New(opts ...Option) ai.Provider {
	options := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	if options.name == "" {
		options.name = "google"
	}

	return &provider{
		options: options,
	}
}
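
// WithAPIKey sets the Gemini API key used to authenticate requests.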
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}
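
// WithName overrides the provider name reported in model metadata
// (defaults to "google").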
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}
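
// WithHeaders merges the given headers into the headers sent with every
// request.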
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}
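
// WithHTTPClient sets a custom HTTP client for the underlying genai client.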
func WithHTTPClient(client *http.Client) Option {
	return func(o *options) {
		o.client = client
	}
}
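
// languageModel adapts a single Gemini model to the ai.LanguageModel
// interface.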
type languageModel struct {
	provider        string
	modelID         string
	client          *genai.Client
	providerOptions options
}

// LanguageModel implements ai.Provider.
func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	cc := &genai.ClientConfig{
		APIKey:     g.options.apiKey,
		Backend:    genai.BackendGeminiAPI,
		HTTPClient: g.options.client,
	}
	// Forward any custom headers to the genai client; previously they were
	// collected but never applied.
	for k, v := range g.options.headers {
		if cc.HTTPOptions.Headers == nil {
			cc.HTTPOptions.Headers = http.Header{}
		}
		cc.HTTPOptions.Headers.Set(k, v)
	}
	client, err := genai.NewClient(context.Background(), cc)
	if err != nil {
		return nil, err
	}
	return &languageModel{
		modelID:         modelID,
		provider:        fmt.Sprintf("%s.generative-ai", g.options.name),
		providerOptions: g.options,
		client:          client,
	}, nil
}
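
// prepareParams translates an ai.Call into the genai request config, the
// conversation contents, and any warnings produced along the way.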
func (g *languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
	config := &genai.GenerateContentConfig{}
	providerOptions := &providerOptions{}
	if v, ok := call.ProviderOptions["google"]; ok {
		err := ai.ParseOptions(v, providerOptions)
		if err != nil {
			return nil, nil, nil, err
		}
	}

	systemInstructions, content, warnings := toGooglePrompt(call.Prompt)

	if providerOptions.ThinkingConfig != nil &&
		providerOptions.ThinkingConfig.IncludeThoughts != nil &&
		*providerOptions.ThinkingConfig.IncludeThoughts &&
		!strings.HasPrefix(g.provider, "google.vertex.") {
		warnings = append(warnings, ai.CallWarning{
			Type: ai.CallWarningTypeOther,
			Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
				"and might not be supported or could behave unexpectedly with the current Google provider " +
				fmt.Sprintf("(%s)", g.provider),
		})
	}
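
	// Gemma models do not support system instructions, so fold them into the
	// first user message instead.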
	isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")

	if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
		if len(content) > 0 && content[0].Role == genai.RoleUser {
			systemParts := []string{}
			for _, sp := range systemInstructions.Parts {
				systemParts = append(systemParts, sp.Text)
			}
			systemMsg := strings.Join(systemParts, "\n")
			content[0].Parts = append([]*genai.Part{
				{
					Text: systemMsg + "\n\n",
				},
			}, content[0].Parts...)
			systemInstructions = nil
		}
	}

	config.SystemInstruction = systemInstructions

	if call.MaxOutputTokens != nil {
		config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
	}

	if call.Temperature != nil {
		tmp := float32(*call.Temperature)
		config.Temperature = &tmp
	}
	if call.TopK != nil {
		tmp := float32(*call.TopK)
		config.TopK = &tmp
	}
	if call.TopP != nil {
		tmp := float32(*call.TopP)
		config.TopP = &tmp
	}
	if call.FrequencyPenalty != nil {
		tmp := float32(*call.FrequencyPenalty)
		config.FrequencyPenalty = &tmp
	}
	if call.PresencePenalty != nil {
		tmp := float32(*call.PresencePenalty)
		config.PresencePenalty = &tmp
	}

	if providerOptions.ThinkingConfig != nil {
		config.ThinkingConfig = &genai.ThinkingConfig{}
		if providerOptions.ThinkingConfig.IncludeThoughts != nil {
			config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
		}
		if providerOptions.ThinkingConfig.ThinkingBudget != nil {
			tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
			config.ThinkingConfig.ThinkingBudget = &tmp
		}
	}
	for _, safetySetting := range providerOptions.SafetySettings {
		config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
			Category:  genai.HarmCategory(safetySetting.Category),
			Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
		})
	}
	if providerOptions.CachedContent != "" {
		config.CachedContent = providerOptions.CachedContent
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
		config.ToolConfig = toolChoice
		config.Tools = append(config.Tools, &genai.Tool{
			FunctionDeclarations: tools,
		})
		warnings = append(warnings, toolWarnings...)
	}

	return config, content, warnings, nil
}
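
// toGooglePrompt converts an ai.Prompt into a genai system instruction plus
// the remaining conversation contents.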
func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) { //nolint: unparam
	var systemInstructions *genai.Content
	var content []*genai.Content
	var warnings []ai.CallWarning
	var systemMessages []string

	finishedSystemBlock := false
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			if finishedSystemBlock {
				// Skip system messages that appear after user or assistant
				// messages; only the leading system block is supported.
				// TODO: see if we need to send an error here?
				continue
			}
			for _, part := range msg.Content {
				text, ok := ai.AsMessagePart[ai.TextPart](part)
				if !ok || text.Text == "" {
					continue
				}
				systemMessages = append(systemMessages, text.Text)
			}
		case ai.MessageRoleUser:
			finishedSystemBlock = true
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeFile:
					file, ok := ai.AsMessagePart[ai.FilePart](part)
					if !ok {
						continue
					}
					// genai's Blob takes the raw bytes; base64 encoding is
					// handled by the SDK during JSON marshaling.
					parts = append(parts, &genai.Part{
						InlineData: &genai.Blob{
							Data:     file.Data,
							MIMEType: file.MediaType,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		case ai.MessageRoleAssistant:
			finishedSystemBlock = true
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeToolCall:
					toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
					if !ok {
						continue
					}

					var result map[string]any
					err := json.Unmarshal([]byte(toolCall.Input), &result)
					if err != nil {
						continue
					}
					parts = append(parts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							ID:   toolCall.ToolCallID,
							Name: toolCall.ToolName,
							Args: result,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleModel,
					Parts: parts,
				})
			}
		case ai.MessageRoleTool:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeToolResult:
					result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
					if !ok {
						continue
					}
					// Look up the originating tool call so the response can
					// carry the function name genai expects.
					var toolCall ai.ToolCallPart
					for _, m := range prompt {
						if m.Role == ai.MessageRoleAssistant {
							for _, content := range m.Content {
								tc, ok := ai.AsMessagePart[ai.ToolCallPart](content)
								if !ok {
									continue
								}
								if tc.ToolCallID == result.ToolCallID {
									toolCall = tc
									break
								}
							}
						}
					}
					switch result.Output.GetType() {
					case ai.ToolResultContentTypeText:
						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": content.Text}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})

					case ai.ToolResultContentTypeError:
						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": content.Error.Error()}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})
					}
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		}
	}
	if len(systemMessages) > 0 {
		systemInstructions = &genai.Content{
			Parts: []*genai.Part{
				{
					Text: strings.Join(systemMessages, "\n"),
				},
			},
		}
	}
	return systemInstructions, content, warnings
}
// Generate implements ai.LanguageModel.
func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}
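
	// genai's chat API takes the conversation history separately from the
	// message being sent, so split the prepared contents into history plus
	// the final message.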
	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}

	response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
	if err != nil {
		return nil, err
	}

	return mapResponse(response, warnings)
}

// Model implements ai.LanguageModel.
func (g *languageModel) Model() string {
	return g.modelID
}

// Provider implements ai.LanguageModel.
func (g *languageModel) Provider() string {
	return g.provider
}

// Stream implements ai.LanguageModel.
func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}

	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}
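
	// The returned stream yields any warnings first, then text and tool-call
	// parts as they arrive from Gemini, and ends with a finish part carrying
	// usage and the finish reason.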
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		var currentContent string
		var toolCalls []ai.ToolCallContent
		var isActiveText bool
		var usage ai.Usage

		// Stream the response.
		for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
			if err != nil {
				yield(ai.StreamPart{
					Type:  ai.StreamPartTypeError,
					Error: err,
				})
				return
			}

			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
				for _, part := range resp.Candidates[0].Content.Parts {
					switch {
					case part.Text != "":
						delta := part.Text
						if delta != "" {
							if !isActiveText {
								isActiveText = true
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeTextStart,
									ID:   "0",
								}) {
									return
								}
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeTextDelta,
								ID:    "0",
								Delta: delta,
							}) {
								return
							}
							currentContent += delta
						}
					case part.FunctionCall != nil:
						if isActiveText {
							isActiveText = false
							if !yield(ai.StreamPart{
								Type: ai.StreamPartTypeTextEnd,
								ID:   "0",
							}) {
								return
							}
						}

						toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())

						args, err := json.Marshal(part.FunctionCall.Args)
						if err != nil {
							yield(ai.StreamPart{
								Type:  ai.StreamPartTypeError,
								Error: err,
							})
							return
						}

						if !yield(ai.StreamPart{
							Type:         ai.StreamPartTypeToolInputStart,
							ID:           toolCallID,
							ToolCallName: part.FunctionCall.Name,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:  ai.StreamPartTypeToolInputDelta,
							ID:    toolCallID,
							Delta: string(args),
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeToolInputEnd,
							ID:   toolCallID,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:             ai.StreamPartTypeToolCall,
							ID:               toolCallID,
							ToolCallName:     part.FunctionCall.Name,
							ToolCallInput:    string(args),
							ProviderExecuted: false,
						}) {
							return
						}

						toolCalls = append(toolCalls, ai.ToolCallContent{
							ToolCallID:       toolCallID,
							ToolName:         part.FunctionCall.Name,
							Input:            string(args),
							ProviderExecuted: false,
						})
					}
				}
			}

			if resp.UsageMetadata != nil {
				usage = mapUsage(resp.UsageMetadata)
			}
		}

		if isActiveText {
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID:   "0",
			}) {
				return
			}
		}

		finishReason := ai.FinishReasonStop
		if len(toolCalls) > 0 {
			finishReason = ai.FinishReasonToolCalls
		}

		yield(ai.StreamPart{
			Type:         ai.StreamPartTypeFinish,
			Usage:        usage,
			FinishReason: finishReason,
		})
	}, nil
}
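
// toGoogleTools converts ai tools and the tool choice into genai function
// declarations and a matching tool config.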
func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}

			required := []string{}
			var properties map[string]any
			if props, ok := ft.InputSchema["properties"]; ok {
				properties, _ = props.(map[string]any)
			}
			// "required" may be a []string or, when the schema was decoded
			// from JSON, a []any of strings.
			switch req := ft.InputSchema["required"].(type) {
			case []string:
				required = req
			case []any:
				for _, r := range req {
					if s, ok := r.(string); ok {
						required = append(required, s)
					}
				}
			}
			declaration := &genai.FunctionDeclaration{
				Name:        ft.Name,
				Description: ft.Description,
				Parameters: &genai.Schema{
					Type:       genai.TypeObject,
					Properties: convertSchemaProperties(properties),
					Required:   required,
				},
			}
			googleTools = append(googleTools, declaration)
			continue
		}
		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return //nolint: nakedret
	}
	switch *toolChoice {
	case ai.ToolChoiceAuto:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAuto,
			},
		}
	case ai.ToolChoiceRequired:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
			},
		}
	case ai.ToolChoiceNone:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeNone,
			},
		}
	default:
		// Any other value is treated as the name of a specific tool the
		// model must call.
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
				AllowedFunctionNames: []string{
					string(*toolChoice),
				},
			},
		}
	}
	return //nolint: nakedret
}
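
// convertSchemaProperties converts a JSON-schema "properties" map into genai
// schemas.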
func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}
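
// convertToSchema converts a single JSON-schema value into a genai.Schema,
// defaulting to a string schema when the input is malformed or untyped.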
func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGoogle(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}
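
// processArrayItems converts the "items" schema of a JSON-schema array, if
// present.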
func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}
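
// mapJSONTypeToGoogle maps a JSON-schema type name to the corresponding
// genai type.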
func mapJSONTypeToGoogle(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types.
	}
}
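
// mapResponse converts a non-streaming genai response into an ai.Response.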
func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
	if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
		return nil, errors.New("no response from model")
	}

	var (
		content      []ai.Content
		finishReason ai.FinishReason
		hasToolCalls bool
		candidate    = response.Candidates[0]
	)

	for _, part := range candidate.Content.Parts {
		switch {
		case part.Text != "":
			content = append(content, ai.TextContent{Text: part.Text})
		case part.FunctionCall != nil:
			input, err := json.Marshal(part.FunctionCall.Args)
			if err != nil {
				return nil, err
			}
			// Fall back to the function name or a generated ID, matching the
			// streaming path, since Gemini may omit the call ID.
			content = append(content, ai.ToolCallContent{
				ToolCallID:       cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString()),
				ToolName:         part.FunctionCall.Name,
				Input:            string(input),
				ProviderExecuted: false,
			})
			hasToolCalls = true
		default:
			return nil, errors.New("unsupported part type")
		}
	}

	if hasToolCalls {
		finishReason = ai.FinishReasonToolCalls
	} else {
		finishReason = mapFinishReason(candidate.FinishReason)
	}

	return &ai.Response{
		Content:      content,
		Usage:        mapUsage(response.UsageMetadata),
		FinishReason: finishReason,
		Warnings:     warnings,
	}, nil
}
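
// mapFinishReason maps genai finish reasons onto the ai package's finish
// reasons.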
func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return ai.FinishReasonStop
	case genai.FinishReasonMaxTokens:
		return ai.FinishReasonLength
	case genai.FinishReasonSafety,
		genai.FinishReasonBlocklist,
		genai.FinishReasonProhibitedContent,
		genai.FinishReasonSPII,
		genai.FinishReasonImageSafety:
		return ai.FinishReasonContentFilter
	case genai.FinishReasonRecitation,
		genai.FinishReasonLanguage,
		genai.FinishReasonMalformedFunctionCall:
		return ai.FinishReasonError
	case genai.FinishReasonOther:
		return ai.FinishReasonOther
	default:
		return ai.FinishReasonUnknown
	}
}
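
// mapUsage converts genai usage metadata into ai.Usage. CachedContentTokenCount
// counts the prompt tokens served from cache, so it maps to cache reads.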
func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
	if usage == nil {
		return ai.Usage{}
	}
	return ai.Usage{
		InputTokens:         int64(usage.PromptTokenCount),
		OutputTokens:        int64(usage.CandidatesTokenCount),
		TotalTokens:         int64(usage.TotalTokenCount),
		ReasoningTokens:     int64(usage.ThoughtsTokenCount),
		CacheCreationTokens: 0,
		CacheReadTokens:     int64(usage.CachedContentTokenCount),
	}
}