1package google
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "fmt"
8 "maps"
9 "net/http"
10 "strings"
11
12 "github.com/charmbracelet/fantasy/ai"
13 "google.golang.org/genai"
14)
15
// provider implements ai.Provider backed by Google's Gemini API.
type provider struct {
	options options // configuration captured at construction time via Option funcs
}

// options holds the configurable settings for the Google provider.
type options struct {
	apiKey  string            // API key used to authenticate with the Gemini API
	name    string            // provider name; defaults to "google" (see New)
	headers map[string]string // extra HTTP headers; NOTE(review): collected but never forwarded to the genai client — confirm intent
	client  *http.Client      // optional custom HTTP client; nil uses the genai default
}

// Option mutates the provider options during construction; pass to New.
type Option = func(*options)
27
28func New(opts ...Option) ai.Provider {
29 options := options{
30 headers: map[string]string{},
31 }
32 for _, o := range opts {
33 o(&options)
34 }
35
36 if options.name == "" {
37 options.name = "google"
38 }
39
40 return &provider{
41 options: options,
42 }
43}
44
45func WithAPIKey(apiKey string) Option {
46 return func(o *options) {
47 o.apiKey = apiKey
48 }
49}
50
51func WithName(name string) Option {
52 return func(o *options) {
53 o.name = name
54 }
55}
56
57func WithHeaders(headers map[string]string) Option {
58 return func(o *options) {
59 maps.Copy(o.headers, headers)
60 }
61}
62
63func WithHTTPClient(client *http.Client) Option {
64 return func(o *options) {
65 o.client = client
66 }
67}
68
// languageModel adapts a genai client to the ai.LanguageModel interface for a
// single Gemini model.
type languageModel struct {
	provider        string        // provider identifier, e.g. "google.generative-ai"
	modelID         string        // Gemini model name this instance serves
	client          *genai.Client // underlying genai API client
	providerOptions options       // provider-level configuration this model was built with
}
75
76// LanguageModel implements ai.Provider.
77func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
78 cc := &genai.ClientConfig{
79 APIKey: g.options.apiKey,
80 Backend: genai.BackendGeminiAPI,
81 HTTPClient: g.options.client,
82 }
83 client, err := genai.NewClient(context.Background(), cc)
84 if err != nil {
85 return nil, err
86 }
87 return &languageModel{
88 modelID: modelID,
89 provider: fmt.Sprintf("%s.generative-ai", g.options.name),
90 providerOptions: g.options,
91 client: client,
92 }, nil
93}
94
95func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
96 config := &genai.GenerateContentConfig{}
97 providerOptions := &providerOptions{}
98 if v, ok := call.ProviderOptions["google"]; ok {
99 err := ai.ParseOptions(v, providerOptions)
100 if err != nil {
101 return nil, nil, nil, err
102 }
103 }
104
105 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
106
107 if providerOptions.ThinkingConfig != nil &&
108 providerOptions.ThinkingConfig.IncludeThoughts != nil &&
109 *providerOptions.ThinkingConfig.IncludeThoughts &&
110 strings.HasPrefix(a.provider, "google.vertex.") {
111 warnings = append(warnings, ai.CallWarning{
112 Type: ai.CallWarningTypeOther,
113 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
114 "and might not be supported or could behave unexpectedly with the current Google provider " +
115 fmt.Sprintf("(%s)", a.provider),
116 })
117 }
118
119 isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")
120
121 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
122 if len(content) > 0 && content[0].Role == genai.RoleUser {
123 systemParts := []string{}
124 for _, sp := range systemInstructions.Parts {
125 systemParts = append(systemParts, sp.Text)
126 }
127 systemMsg := strings.Join(systemParts, "\n")
128 content[0].Parts = append([]*genai.Part{
129 {
130 Text: systemMsg + "\n\n",
131 },
132 }, content[0].Parts...)
133 systemInstructions = nil
134 }
135 }
136
137 config.SystemInstruction = systemInstructions
138
139 if call.MaxOutputTokens != nil {
140 config.MaxOutputTokens = int32(*call.MaxOutputTokens)
141 }
142
143 if call.Temperature != nil {
144 tmp := float32(*call.Temperature)
145 config.Temperature = &tmp
146 }
147 if call.TopK != nil {
148 tmp := float32(*call.TopK)
149 config.TopK = &tmp
150 }
151 if call.TopP != nil {
152 tmp := float32(*call.TopP)
153 config.TopP = &tmp
154 }
155 if call.FrequencyPenalty != nil {
156 tmp := float32(*call.FrequencyPenalty)
157 config.FrequencyPenalty = &tmp
158 }
159 if call.PresencePenalty != nil {
160 tmp := float32(*call.PresencePenalty)
161 config.PresencePenalty = &tmp
162 }
163
164 if providerOptions.ThinkingConfig != nil {
165 config.ThinkingConfig = &genai.ThinkingConfig{}
166 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
167 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
168 }
169 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
170 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget)
171 config.ThinkingConfig.ThinkingBudget = &tmp
172 }
173 }
174 for _, safetySetting := range providerOptions.SafetySettings {
175 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
176 Category: genai.HarmCategory(safetySetting.Category),
177 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
178 })
179 }
180 if providerOptions.CachedContent != "" {
181 config.CachedContent = providerOptions.CachedContent
182 }
183
184 if len(call.Tools) > 0 {
185 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
186 config.ToolConfig = toolChoice
187 config.Tools = append(config.Tools, &genai.Tool{
188 FunctionDeclarations: tools,
189 })
190 warnings = append(warnings, toolWarnings...)
191 }
192
193 return config, content, warnings, nil
194}
195
196func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) {
197 var systemInstructions *genai.Content
198 var content []*genai.Content
199 var warnings []ai.CallWarning
200
201 finishedSystemBlock := false
202 for _, msg := range prompt {
203 switch msg.Role {
204 case ai.MessageRoleSystem:
205 if finishedSystemBlock {
206 // skip multiple system messages that are separated by user/assistant messages
207 // TODO: see if we need to send error here?
208 continue
209 }
210 finishedSystemBlock = true
211
212 var systemMessages []string
213 for _, part := range msg.Content {
214 text, ok := ai.AsMessagePart[ai.TextPart](part)
215 if !ok || text.Text == "" {
216 continue
217 }
218 systemMessages = append(systemMessages, text.Text)
219 }
220 if len(systemMessages) > 0 {
221 systemInstructions = &genai.Content{
222 Parts: []*genai.Part{
223 {
224 Text: strings.Join(systemMessages, "\n"),
225 },
226 },
227 }
228 }
229 case ai.MessageRoleUser:
230 var parts []*genai.Part
231 for _, part := range msg.Content {
232 switch part.GetType() {
233 case ai.ContentTypeText:
234 text, ok := ai.AsMessagePart[ai.TextPart](part)
235 if !ok || text.Text == "" {
236 continue
237 }
238 parts = append(parts, &genai.Part{
239 Text: text.Text,
240 })
241 case ai.ContentTypeFile:
242 file, ok := ai.AsMessagePart[ai.FilePart](part)
243 if !ok {
244 continue
245 }
246 var encoded []byte
247 base64.StdEncoding.Encode(encoded, file.Data)
248 parts = append(parts, &genai.Part{
249 InlineData: &genai.Blob{
250 Data: encoded,
251 MIMEType: file.MediaType,
252 },
253 })
254 }
255 }
256 if len(parts) > 0 {
257 content = append(content, &genai.Content{
258 Role: genai.RoleUser,
259 Parts: parts,
260 })
261 }
262 case ai.MessageRoleAssistant:
263 var parts []*genai.Part
264 for _, part := range msg.Content {
265 switch part.GetType() {
266 case ai.ContentTypeText:
267 text, ok := ai.AsMessagePart[ai.TextPart](part)
268 if !ok || text.Text == "" {
269 continue
270 }
271 parts = append(parts, &genai.Part{
272 Text: text.Text,
273 })
274 case ai.ContentTypeToolCall:
275 toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
276 if !ok {
277 continue
278 }
279
280 var result map[string]any
281 err := json.Unmarshal([]byte(toolCall.Input), &result)
282 if err != nil {
283 continue
284 }
285 parts = append(parts, &genai.Part{
286 FunctionCall: &genai.FunctionCall{
287 ID: toolCall.ToolCallID,
288 Name: toolCall.ToolName,
289 Args: result,
290 },
291 })
292 }
293 }
294 if len(parts) > 0 {
295 content = append(content, &genai.Content{
296 Role: genai.RoleModel,
297 Parts: parts,
298 })
299 }
300 case ai.MessageRoleTool:
301 var parts []*genai.Part
302 for _, part := range msg.Content {
303 switch part.GetType() {
304 case ai.ContentTypeToolResult:
305 result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
306 if !ok {
307 continue
308 }
309 var toolCall ai.ToolCallPart
310 for _, m := range prompt {
311 if m.Role == ai.MessageRoleAssistant {
312 for _, content := range m.Content {
313 tc, ok := ai.AsMessagePart[ai.ToolCallPart](content)
314 if !ok {
315 continue
316 }
317 if tc.ToolCallID == result.ToolCallID {
318 toolCall = tc
319 break
320 }
321 }
322 }
323 }
324 switch result.Output.GetType() {
325 case ai.ToolResultContentTypeText:
326 content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
327 if !ok {
328 continue
329 }
330 response := map[string]any{"result": content.Text}
331 parts = append(parts, &genai.Part{
332 FunctionResponse: &genai.FunctionResponse{
333 ID: result.ToolCallID,
334 Response: response,
335 Name: toolCall.ToolName,
336 },
337 })
338
339 case ai.ToolResultContentTypeError:
340 content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
341 if !ok {
342 continue
343 }
344 response := map[string]any{"result": content.Error.Error()}
345 parts = append(parts, &genai.Part{
346 FunctionResponse: &genai.FunctionResponse{
347 ID: result.ToolCallID,
348 Response: response,
349 Name: toolCall.ToolName,
350 },
351 })
352
353 }
354 }
355 }
356 if len(parts) > 0 {
357 content = append(content, &genai.Content{
358 Role: genai.RoleUser,
359 Parts: parts,
360 })
361 }
362 }
363 }
364 return systemInstructions, content, warnings
365}
366
// Generate implements ai.LanguageModel.
//
// TODO: not yet implemented; the commented-out code suggests prepareParams
// will supply the request configuration once this is filled in.
func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	// params, err := g.prepareParams(call)
	// if err != nil {
	// return nil, err
	// }
	panic("unimplemented")
}
375
// Model implements ai.LanguageModel. It returns the Gemini model ID this
// instance was created for.
func (g *languageModel) Model() string {
	return g.modelID
}
380
// Provider implements ai.LanguageModel. It returns the provider identifier,
// formatted as "<name>.generative-ai" (see LanguageModel).
func (g *languageModel) Provider() string {
	return g.provider
}
385
// Stream implements ai.LanguageModel.
//
// TODO: not yet implemented.
func (g *languageModel) Stream(context.Context, ai.Call) (ai.StreamResponse, error) {
	panic("unimplemented")
}
390
391func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
392 for _, tool := range tools {
393 if tool.GetType() == ai.ToolTypeFunction {
394 ft, ok := tool.(ai.FunctionTool)
395 if !ok {
396 continue
397 }
398
399 required := []string{}
400 var properties map[string]any
401 if props, ok := ft.InputSchema["properties"]; ok {
402 properties, _ = props.(map[string]any)
403 }
404 if req, ok := ft.InputSchema["required"]; ok {
405 if reqArr, ok := req.([]string); ok {
406 required = reqArr
407 }
408 }
409 declaration := &genai.FunctionDeclaration{
410 Name: ft.Name,
411 Description: ft.Description,
412 Parameters: &genai.Schema{
413 Type: genai.TypeObject,
414 Properties: convertSchemaProperties(properties),
415 Required: required,
416 },
417 }
418 googleTools = append(googleTools, declaration)
419 continue
420 }
421 // TODO: handle provider tool calls
422 warnings = append(warnings, ai.CallWarning{
423 Type: ai.CallWarningTypeUnsupportedTool,
424 Tool: tool,
425 Message: "tool is not supported",
426 })
427 }
428 if toolChoice == nil {
429 return
430 }
431 switch *toolChoice {
432 case ai.ToolChoiceAuto:
433 googleToolChoice = &genai.ToolConfig{
434 FunctionCallingConfig: &genai.FunctionCallingConfig{
435 Mode: genai.FunctionCallingConfigModeAuto,
436 },
437 }
438 case ai.ToolChoiceRequired:
439 googleToolChoice = &genai.ToolConfig{
440 FunctionCallingConfig: &genai.FunctionCallingConfig{
441 Mode: genai.FunctionCallingConfigModeAny,
442 },
443 }
444 case ai.ToolChoiceNone:
445 googleToolChoice = &genai.ToolConfig{
446 FunctionCallingConfig: &genai.FunctionCallingConfig{
447 Mode: genai.FunctionCallingConfigModeNone,
448 },
449 }
450 default:
451 googleToolChoice = &genai.ToolConfig{
452 FunctionCallingConfig: &genai.FunctionCallingConfig{
453 Mode: genai.FunctionCallingConfigModeAny,
454 AllowedFunctionNames: []string{
455 string(*toolChoice),
456 },
457 },
458 }
459 }
460 return
461}
462
463func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
464 properties := make(map[string]*genai.Schema)
465
466 for name, param := range parameters {
467 properties[name] = convertToSchema(param)
468 }
469
470 return properties
471}
472
473func convertToSchema(param any) *genai.Schema {
474 schema := &genai.Schema{Type: genai.TypeString}
475
476 paramMap, ok := param.(map[string]any)
477 if !ok {
478 return schema
479 }
480
481 if desc, ok := paramMap["description"].(string); ok {
482 schema.Description = desc
483 }
484
485 typeVal, hasType := paramMap["type"]
486 if !hasType {
487 return schema
488 }
489
490 typeStr, ok := typeVal.(string)
491 if !ok {
492 return schema
493 }
494
495 schema.Type = mapJSONTypeToGoogle(typeStr)
496
497 switch typeStr {
498 case "array":
499 schema.Items = processArrayItems(paramMap)
500 case "object":
501 if props, ok := paramMap["properties"].(map[string]any); ok {
502 schema.Properties = convertSchemaProperties(props)
503 }
504 }
505
506 return schema
507}
508
509func processArrayItems(paramMap map[string]any) *genai.Schema {
510 items, ok := paramMap["items"].(map[string]any)
511 if !ok {
512 return nil
513 }
514
515 return convertToSchema(items)
516}
517
518func mapJSONTypeToGoogle(jsonType string) genai.Type {
519 switch jsonType {
520 case "string":
521 return genai.TypeString
522 case "number":
523 return genai.TypeNumber
524 case "integer":
525 return genai.TypeInteger
526 case "boolean":
527 return genai.TypeBoolean
528 case "array":
529 return genai.TypeArray
530 case "object":
531 return genai.TypeObject
532 default:
533 return genai.TypeString // Default to string for unknown types
534 }
535}