package google

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"maps"
	"net/http"
	"strings"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/x/exp/slice"
	"google.golang.org/genai"
)

type provider struct {
	options options
}

type options struct {
	apiKey  string
	name    string
	headers map[string]string
	client  *http.Client
}

type Option = func(*options)

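// New creates a Google provider configured by the given options. A minimal
// usage sketch (the model ID below is illustrative, not prescriptive):
//
//	p := New(WithAPIKey(os.Getenv("GOOGLE_API_KEY")))
//	model, err := p.LanguageModel("gemini-2.0-flash")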
func New(opts ...Option) ai.Provider {
	options := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	if options.name == "" {
		options.name = "google"
	}

	return &provider{
		options: options,
	}
}

// WithAPIKey sets the API key used to authenticate against the Gemini API.
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

// WithName overrides the provider name reported by models (defaults to
// "google").
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

// WithHeaders merges the given headers into the set sent with requests.
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

// WithHTTPClient sets a custom HTTP client for the underlying genai client.
func WithHTTPClient(client *http.Client) Option {
	return func(o *options) {
		o.client = client
	}
}

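// languageModel adapts a Gemini model served by google.golang.org/genai to
// the ai.LanguageModel interface.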
type languageModel struct {
	provider        string
	modelID         string
	client          *genai.Client
	providerOptions options
}

// LanguageModel implements ai.Provider.
func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	cc := &genai.ClientConfig{
		APIKey:     g.options.apiKey,
		Backend:    genai.BackendGeminiAPI,
		HTTPClient: g.options.client,
	}
	client, err := genai.NewClient(context.Background(), cc)
	if err != nil {
		return nil, err
	}
	return &languageModel{
		modelID:         modelID,
		provider:        fmt.Sprintf("%s.generative-ai", g.options.name),
		providerOptions: g.options,
		client:          client,
	}, nil
}

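// prepareParams translates an ai.Call into a genai generation config and
// content list, accumulating non-fatal issues as ai.CallWarnings.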
func (g *languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
	config := &genai.GenerateContentConfig{}
	providerOptions := &providerOptions{}
	if v, ok := call.ProviderOptions["google"]; ok {
		err := ai.ParseOptions(v, providerOptions)
		if err != nil {
			return nil, nil, nil, err
		}
	}

	systemInstructions, content, warnings := toGooglePrompt(call.Prompt)

	// Warn when 'includeThoughts' is requested on a provider other than
	// Google Vertex, which is the only provider that supports it.
	if providerOptions.ThinkingConfig != nil &&
		providerOptions.ThinkingConfig.IncludeThoughts != nil &&
		*providerOptions.ThinkingConfig.IncludeThoughts &&
		!strings.HasPrefix(g.provider, "google.vertex.") {
		warnings = append(warnings, ai.CallWarning{
			Type: ai.CallWarningTypeOther,
			Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
				"and might not be supported or could behave unexpectedly with the current Google provider " +
				fmt.Sprintf("(%s)", g.provider),
		})
	}

	isGemmaModel := strings.HasPrefix(strings.ToLower(g.modelID), "gemma-")

	// Gemma models don't support a system role, so fold the system
	// instructions into the first user message.
	if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
		if len(content) > 0 && content[0].Role == genai.RoleUser {
			systemParts := []string{}
			for _, sp := range systemInstructions.Parts {
				systemParts = append(systemParts, sp.Text)
			}
			systemMsg := strings.Join(systemParts, "\n")
			content[0].Parts = append([]*genai.Part{
				{
					Text: systemMsg + "\n\n",
				},
			}, content[0].Parts...)
			systemInstructions = nil
		}
	}

	config.SystemInstruction = systemInstructions

	if call.MaxOutputTokens != nil {
		config.MaxOutputTokens = int32(*call.MaxOutputTokens)
	}

	if call.Temperature != nil {
		tmp := float32(*call.Temperature)
		config.Temperature = &tmp
	}
	if call.TopK != nil {
		tmp := float32(*call.TopK)
		config.TopK = &tmp
	}
	if call.TopP != nil {
		tmp := float32(*call.TopP)
		config.TopP = &tmp
	}
	if call.FrequencyPenalty != nil {
		tmp := float32(*call.FrequencyPenalty)
		config.FrequencyPenalty = &tmp
	}
	if call.PresencePenalty != nil {
		tmp := float32(*call.PresencePenalty)
		config.PresencePenalty = &tmp
	}

	if providerOptions.ThinkingConfig != nil {
		config.ThinkingConfig = &genai.ThinkingConfig{}
		if providerOptions.ThinkingConfig.IncludeThoughts != nil {
			config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
		}
		if providerOptions.ThinkingConfig.ThinkingBudget != nil {
			tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget)
			config.ThinkingConfig.ThinkingBudget = &tmp
		}
	}
	for _, safetySetting := range providerOptions.SafetySettings {
		config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
			Category:  genai.HarmCategory(safetySetting.Category),
			Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
		})
	}
	if providerOptions.CachedContent != "" {
		config.CachedContent = providerOptions.CachedContent
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
		config.ToolConfig = toolChoice
		config.Tools = append(config.Tools, &genai.Tool{
			FunctionDeclarations: tools,
		})
		warnings = append(warnings, toolWarnings...)
	}

	return config, content, warnings, nil
}

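// toGooglePrompt maps an ai.Prompt onto genai contents: system messages
// become a SystemInstruction, assistant messages take the model role, and
// tool results are sent back as user-role FunctionResponse parts.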
func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) {
	var systemInstructions *genai.Content
	var content []*genai.Content
	var warnings []ai.CallWarning

	finishedSystemBlock := false
	for _, msg := range prompt {
		if msg.Role != ai.MessageRoleSystem {
			// Once a non-system message appears, the leading system
			// block is over.
			finishedSystemBlock = true
		}
		switch msg.Role {
		case ai.MessageRoleSystem:
			if finishedSystemBlock {
				// Skip system messages that are separated from the
				// leading block by user/assistant messages.
				// TODO: see if we need to send an error here?
				continue
			}

			// Accumulate consecutive leading system messages into a
			// single instruction.
			var systemMessages []string
			if systemInstructions != nil {
				for _, p := range systemInstructions.Parts {
					systemMessages = append(systemMessages, p.Text)
				}
			}
			for _, part := range msg.Content {
				text, ok := ai.AsMessagePart[ai.TextPart](part)
				if !ok || text.Text == "" {
					continue
				}
				systemMessages = append(systemMessages, text.Text)
			}
			if len(systemMessages) > 0 {
				systemInstructions = &genai.Content{
					Parts: []*genai.Part{
						{
							Text: strings.Join(systemMessages, "\n"),
						},
					},
				}
			}
		case ai.MessageRoleUser:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeFile:
					file, ok := ai.AsMessagePart[ai.FilePart](part)
					if !ok {
						continue
					}
					// The destination buffer must be allocated before
					// encoding; encoding into a nil slice panics.
					encoded := make([]byte, base64.StdEncoding.EncodedLen(len(file.Data)))
					base64.StdEncoding.Encode(encoded, file.Data)
					parts = append(parts, &genai.Part{
						InlineData: &genai.Blob{
							Data:     encoded,
							MIMEType: file.MediaType,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		case ai.MessageRoleAssistant:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeText:
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok || text.Text == "" {
						continue
					}
					parts = append(parts, &genai.Part{
						Text: text.Text,
					})
				case ai.ContentTypeToolCall:
					toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
					if !ok {
						continue
					}

					var result map[string]any
					err := json.Unmarshal([]byte(toolCall.Input), &result)
					if err != nil {
						continue
					}
					parts = append(parts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							ID:   toolCall.ToolCallID,
							Name: toolCall.ToolName,
							Args: result,
						},
					})
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleModel,
					Parts: parts,
				})
			}
		case ai.MessageRoleTool:
			var parts []*genai.Part
			for _, part := range msg.Content {
				switch part.GetType() {
				case ai.ContentTypeToolResult:
					result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
					if !ok {
						continue
					}
					// Find the matching tool call so the response can
					// carry the original tool name.
					var toolCall ai.ToolCallPart
					for _, m := range prompt {
						if m.Role == ai.MessageRoleAssistant {
							for _, content := range m.Content {
								tc, ok := ai.AsMessagePart[ai.ToolCallPart](content)
								if !ok {
									continue
								}
								if tc.ToolCallID == result.ToolCallID {
									toolCall = tc
									break
								}
							}
						}
					}
					switch result.Output.GetType() {
					case ai.ToolResultContentTypeText:
						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": content.Text}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})

					case ai.ToolResultContentTypeError:
						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
						if !ok {
							continue
						}
						response := map[string]any{"result": content.Error.Error()}
						parts = append(parts, &genai.Part{
							FunctionResponse: &genai.FunctionResponse{
								ID:       result.ToolCallID,
								Response: response,
								Name:     toolCall.ToolName,
							},
						})
					}
				}
			}
			if len(parts) > 0 {
				content = append(content, &genai.Content{
					Role:  genai.RoleUser,
					Parts: parts,
				})
			}
		}
	}
	return systemInstructions, content, warnings
}

// Generate implements ai.LanguageModel.
func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}

	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}

	response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
	if err != nil {
		return nil, err
	}

	return mapResponse(response, warnings)
}

// Model implements ai.LanguageModel.
func (g *languageModel) Model() string {
	return g.modelID
}

// Provider implements ai.LanguageModel.
func (g *languageModel) Provider() string {
	return g.provider
}

// Stream implements ai.LanguageModel.
func (g *languageModel) Stream(context.Context, ai.Call) (ai.StreamResponse, error) {
	return nil, errors.New("unimplemented")
}

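// toGoogleTools converts function tools and a tool-choice directive into
// genai function declarations plus a ToolConfig; unsupported tool types are
// reported as warnings.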
func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}

			required := []string{}
			var properties map[string]any
			if props, ok := ft.InputSchema["properties"]; ok {
				properties, _ = props.(map[string]any)
			}
			if req, ok := ft.InputSchema["required"]; ok {
				// Schemas decoded from JSON carry []any rather than
				// []string, so accept both shapes.
				switch reqVal := req.(type) {
				case []string:
					required = reqVal
				case []any:
					for _, r := range reqVal {
						if s, ok := r.(string); ok {
							required = append(required, s)
						}
					}
				}
			}
			declaration := &genai.FunctionDeclaration{
				Name:        ft.Name,
				Description: ft.Description,
				Parameters: &genai.Schema{
					Type:       genai.TypeObject,
					Properties: convertSchemaProperties(properties),
					Required:   required,
				},
			}
			googleTools = append(googleTools, declaration)
			continue
		}
		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return
	}
	switch *toolChoice {
	case ai.ToolChoiceAuto:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAuto,
			},
		}
	case ai.ToolChoiceRequired:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
			},
		}
	case ai.ToolChoiceNone:
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeNone,
			},
		}
	default:
		// Any other value names a specific tool the model must call.
		googleToolChoice = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
				AllowedFunctionNames: []string{
					string(*toolChoice),
				},
			},
		}
	}
	return
}

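// convertSchemaProperties converts a JSON Schema "properties" map into the
// equivalent map of genai schemas.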
func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
	properties := make(map[string]*genai.Schema)

	for name, param := range parameters {
		properties[name] = convertToSchema(param)
	}

	return properties
}

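// convertToSchema converts a single JSON Schema fragment into a
// genai.Schema, defaulting to a string schema when the fragment is not a
// map or omits a type. For example, a hypothetical fragment such as
//
//	{"type": "array", "items": {"type": "integer"}, "description": "ids"}
//
// becomes a TypeArray schema whose Items schema has TypeInteger.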
func convertToSchema(param any) *genai.Schema {
	schema := &genai.Schema{Type: genai.TypeString}

	paramMap, ok := param.(map[string]any)
	if !ok {
		return schema
	}

	if desc, ok := paramMap["description"].(string); ok {
		schema.Description = desc
	}

	typeVal, hasType := paramMap["type"]
	if !hasType {
		return schema
	}

	typeStr, ok := typeVal.(string)
	if !ok {
		return schema
	}

	schema.Type = mapJSONTypeToGoogle(typeStr)

	switch typeStr {
	case "array":
		schema.Items = processArrayItems(paramMap)
	case "object":
		if props, ok := paramMap["properties"].(map[string]any); ok {
			schema.Properties = convertSchemaProperties(props)
		}
	}

	return schema
}

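// processArrayItems converts the "items" schema of an array property, or
// returns nil when no item schema is present.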
func processArrayItems(paramMap map[string]any) *genai.Schema {
	items, ok := paramMap["items"].(map[string]any)
	if !ok {
		return nil
	}

	return convertToSchema(items)
}

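// mapJSONTypeToGoogle maps a JSON Schema type name to the corresponding
// genai.Type.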
func mapJSONTypeToGoogle(jsonType string) genai.Type {
	switch jsonType {
	case "string":
		return genai.TypeString
	case "number":
		return genai.TypeNumber
	case "integer":
		return genai.TypeInteger
	case "boolean":
		return genai.TypeBoolean
	case "array":
		return genai.TypeArray
	case "object":
		return genai.TypeObject
	default:
		return genai.TypeString // Default to string for unknown types
	}
}

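// mapResponse converts a genai response into an ai.Response, turning text
// and function-call parts into ai content and deriving the finish reason.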
func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
	if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
		return nil, errors.New("no response from model")
	}

	var (
		content      []ai.Content
		finishReason ai.FinishReason
		hasToolCalls bool
		candidate    = response.Candidates[0]
	)

	for _, part := range candidate.Content.Parts {
		switch {
		case part.Text != "":
			content = append(content, ai.TextContent{Text: part.Text})
		case part.FunctionCall != nil:
			input, err := json.Marshal(part.FunctionCall.Args)
			if err != nil {
				return nil, err
			}
			content = append(content, ai.ToolCallContent{
				ToolCallID:       part.FunctionCall.ID,
				ToolName:         part.FunctionCall.Name,
				Input:            string(input),
				ProviderExecuted: false,
			})
			hasToolCalls = true
		default:
			return nil, errors.New("unsupported response part type")
		}
	}

	if hasToolCalls {
		finishReason = ai.FinishReasonToolCalls
	} else {
		finishReason = mapFinishReason(candidate.FinishReason)
	}

	return &ai.Response{
		Content:      content,
		Usage:        mapUsage(response.UsageMetadata),
		FinishReason: finishReason,
		Warnings:     warnings,
	}, nil
}

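// mapFinishReason translates genai finish reasons into their ai
// equivalents; safety-related reasons collapse to content-filter.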
func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
	switch reason {
	case genai.FinishReasonStop:
		return ai.FinishReasonStop
	case genai.FinishReasonMaxTokens:
		return ai.FinishReasonLength
	case genai.FinishReasonSafety,
		genai.FinishReasonBlocklist,
		genai.FinishReasonProhibitedContent,
		genai.FinishReasonSPII,
		genai.FinishReasonImageSafety:
		return ai.FinishReasonContentFilter
	case genai.FinishReasonRecitation,
		genai.FinishReasonLanguage,
		genai.FinishReasonMalformedFunctionCall:
		return ai.FinishReasonError
	case genai.FinishReasonOther:
		return ai.FinishReasonOther
	default:
		return ai.FinishReasonUnknown
	}
}

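// mapUsage converts genai usage metadata into ai.Usage, guarding against a
// nil metadata pointer.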
func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
	if usage == nil {
		return ai.Usage{}
	}
	return ai.Usage{
		// PromptTokenCount covers all input tokens; the previous
		// ToolUsePromptTokenCount only counts tool-use prompt tokens.
		InputTokens:     int64(usage.PromptTokenCount),
		OutputTokens:    int64(usage.CandidatesTokenCount),
		TotalTokens:     int64(usage.TotalTokenCount),
		ReasoningTokens: int64(usage.ThoughtsTokenCount),
		// CachedContentTokenCount reports tokens served from the cache,
		// i.e. cache reads rather than cache creation.
		CacheReadTokens: int64(usage.CachedContentTokenCount),
	}
}