1package google
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "maps"
10 "net/http"
11 "strings"
12
13 "github.com/charmbracelet/fantasy/ai"
14 "google.golang.org/genai"
15)
16
// provider implements ai.Provider for Google's Generative AI (Gemini) API.
type provider struct {
	options options // configuration collected from the Option functions passed to New
}
// options holds the configurable state of the Google provider.
type options struct {
	apiKey  string            // Gemini API key used to authenticate requests
	name    string            // provider name; defaults to "google" in New
	headers map[string]string // extra headers merged in via WithHeaders
	client  *http.Client      // optional custom HTTP client; nil means the SDK default
}

// Option mutates the provider options during construction via New.
type Option = func(*options)
28
29func New(opts ...Option) ai.Provider {
30 options := options{
31 headers: map[string]string{},
32 }
33 for _, o := range opts {
34 o(&options)
35 }
36
37 if options.name == "" {
38 options.name = "google"
39 }
40
41 return &provider{
42 options: options,
43 }
44}
45
46func WithAPIKey(apiKey string) Option {
47 return func(o *options) {
48 o.apiKey = apiKey
49 }
50}
51
52func WithName(name string) Option {
53 return func(o *options) {
54 o.name = name
55 }
56}
57
58func WithHeaders(headers map[string]string) Option {
59 return func(o *options) {
60 maps.Copy(o.headers, headers)
61 }
62}
63
64func WithHTTPClient(client *http.Client) Option {
65 return func(o *options) {
66 o.client = client
67 }
68}
69
// languageModel is an ai.LanguageModel backed by a genai client for a single
// Gemini model ID.
type languageModel struct {
	provider        string        // provider label, e.g. "<name>.generative-ai"
	modelID         string        // Gemini model identifier this instance targets
	client          *genai.Client // underlying Google genai SDK client
	providerOptions options       // copy of the provider's configuration
}
76
77// LanguageModel implements ai.Provider.
78func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
79 cc := &genai.ClientConfig{
80 APIKey: g.options.apiKey,
81 Backend: genai.BackendGeminiAPI,
82 HTTPClient: g.options.client,
83 }
84 client, err := genai.NewClient(context.Background(), cc)
85 if err != nil {
86 return nil, err
87 }
88 return &languageModel{
89 modelID: modelID,
90 provider: fmt.Sprintf("%s.generative-ai", g.options.name),
91 providerOptions: g.options,
92 client: client,
93 }, nil
94}
95
96func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
97 config := &genai.GenerateContentConfig{}
98 providerOptions := &providerOptions{}
99 if v, ok := call.ProviderOptions["google"]; ok {
100 err := ai.ParseOptions(v, providerOptions)
101 if err != nil {
102 return nil, nil, nil, err
103 }
104 }
105
106 systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
107
108 if providerOptions.ThinkingConfig != nil &&
109 providerOptions.ThinkingConfig.IncludeThoughts != nil &&
110 *providerOptions.ThinkingConfig.IncludeThoughts &&
111 strings.HasPrefix(a.provider, "google.vertex.") {
112 warnings = append(warnings, ai.CallWarning{
113 Type: ai.CallWarningTypeOther,
114 Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
115 "and might not be supported or could behave unexpectedly with the current Google provider " +
116 fmt.Sprintf("(%s)", a.provider),
117 })
118 }
119
120 isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")
121
122 if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
123 if len(content) > 0 && content[0].Role == genai.RoleUser {
124 systemParts := []string{}
125 for _, sp := range systemInstructions.Parts {
126 systemParts = append(systemParts, sp.Text)
127 }
128 systemMsg := strings.Join(systemParts, "\n")
129 content[0].Parts = append([]*genai.Part{
130 {
131 Text: systemMsg + "\n\n",
132 },
133 }, content[0].Parts...)
134 systemInstructions = nil
135 }
136 }
137
138 config.SystemInstruction = systemInstructions
139
140 if call.MaxOutputTokens != nil {
141 config.MaxOutputTokens = int32(*call.MaxOutputTokens)
142 }
143
144 if call.Temperature != nil {
145 tmp := float32(*call.Temperature)
146 config.Temperature = &tmp
147 }
148 if call.TopK != nil {
149 tmp := float32(*call.TopK)
150 config.TopK = &tmp
151 }
152 if call.TopP != nil {
153 tmp := float32(*call.TopP)
154 config.TopP = &tmp
155 }
156 if call.FrequencyPenalty != nil {
157 tmp := float32(*call.FrequencyPenalty)
158 config.FrequencyPenalty = &tmp
159 }
160 if call.PresencePenalty != nil {
161 tmp := float32(*call.PresencePenalty)
162 config.PresencePenalty = &tmp
163 }
164
165 if providerOptions.ThinkingConfig != nil {
166 config.ThinkingConfig = &genai.ThinkingConfig{}
167 if providerOptions.ThinkingConfig.IncludeThoughts != nil {
168 config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
169 }
170 if providerOptions.ThinkingConfig.ThinkingBudget != nil {
171 tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget)
172 config.ThinkingConfig.ThinkingBudget = &tmp
173 }
174 }
175 for _, safetySetting := range providerOptions.SafetySettings {
176 config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
177 Category: genai.HarmCategory(safetySetting.Category),
178 Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
179 })
180 }
181 if providerOptions.CachedContent != "" {
182 config.CachedContent = providerOptions.CachedContent
183 }
184
185 if len(call.Tools) > 0 {
186 tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
187 config.ToolConfig = toolChoice
188 config.Tools = append(config.Tools, &genai.Tool{
189 FunctionDeclarations: tools,
190 })
191 warnings = append(warnings, toolWarnings...)
192 }
193
194 return config, content, warnings, nil
195}
196
197func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) {
198 var systemInstructions *genai.Content
199 var content []*genai.Content
200 var warnings []ai.CallWarning
201
202 finishedSystemBlock := false
203 for _, msg := range prompt {
204 switch msg.Role {
205 case ai.MessageRoleSystem:
206 if finishedSystemBlock {
207 // skip multiple system messages that are separated by user/assistant messages
208 // TODO: see if we need to send error here?
209 continue
210 }
211 finishedSystemBlock = true
212
213 var systemMessages []string
214 for _, part := range msg.Content {
215 text, ok := ai.AsMessagePart[ai.TextPart](part)
216 if !ok || text.Text == "" {
217 continue
218 }
219 systemMessages = append(systemMessages, text.Text)
220 }
221 if len(systemMessages) > 0 {
222 systemInstructions = &genai.Content{
223 Parts: []*genai.Part{
224 {
225 Text: strings.Join(systemMessages, "\n"),
226 },
227 },
228 }
229 }
230 case ai.MessageRoleUser:
231 var parts []*genai.Part
232 for _, part := range msg.Content {
233 switch part.GetType() {
234 case ai.ContentTypeText:
235 text, ok := ai.AsMessagePart[ai.TextPart](part)
236 if !ok || text.Text == "" {
237 continue
238 }
239 parts = append(parts, &genai.Part{
240 Text: text.Text,
241 })
242 case ai.ContentTypeFile:
243 file, ok := ai.AsMessagePart[ai.FilePart](part)
244 if !ok {
245 continue
246 }
247 var encoded []byte
248 base64.StdEncoding.Encode(encoded, file.Data)
249 parts = append(parts, &genai.Part{
250 InlineData: &genai.Blob{
251 Data: encoded,
252 MIMEType: file.MediaType,
253 },
254 })
255 }
256 }
257 if len(parts) > 0 {
258 content = append(content, &genai.Content{
259 Role: genai.RoleUser,
260 Parts: parts,
261 })
262 }
263 case ai.MessageRoleAssistant:
264 var parts []*genai.Part
265 for _, part := range msg.Content {
266 switch part.GetType() {
267 case ai.ContentTypeText:
268 text, ok := ai.AsMessagePart[ai.TextPart](part)
269 if !ok || text.Text == "" {
270 continue
271 }
272 parts = append(parts, &genai.Part{
273 Text: text.Text,
274 })
275 case ai.ContentTypeToolCall:
276 toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
277 if !ok {
278 continue
279 }
280
281 var result map[string]any
282 err := json.Unmarshal([]byte(toolCall.Input), &result)
283 if err != nil {
284 continue
285 }
286 parts = append(parts, &genai.Part{
287 FunctionCall: &genai.FunctionCall{
288 ID: toolCall.ToolCallID,
289 Name: toolCall.ToolName,
290 Args: result,
291 },
292 })
293 }
294 }
295 if len(parts) > 0 {
296 content = append(content, &genai.Content{
297 Role: genai.RoleModel,
298 Parts: parts,
299 })
300 }
301 case ai.MessageRoleTool:
302 var parts []*genai.Part
303 for _, part := range msg.Content {
304 switch part.GetType() {
305 case ai.ContentTypeToolResult:
306 result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
307 if !ok {
308 continue
309 }
310 var toolCall ai.ToolCallPart
311 for _, m := range prompt {
312 if m.Role == ai.MessageRoleAssistant {
313 for _, content := range m.Content {
314 tc, ok := ai.AsMessagePart[ai.ToolCallPart](content)
315 if !ok {
316 continue
317 }
318 if tc.ToolCallID == result.ToolCallID {
319 toolCall = tc
320 break
321 }
322 }
323 }
324 }
325 switch result.Output.GetType() {
326 case ai.ToolResultContentTypeText:
327 content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
328 if !ok {
329 continue
330 }
331 response := map[string]any{"result": content.Text}
332 parts = append(parts, &genai.Part{
333 FunctionResponse: &genai.FunctionResponse{
334 ID: result.ToolCallID,
335 Response: response,
336 Name: toolCall.ToolName,
337 },
338 })
339
340 case ai.ToolResultContentTypeError:
341 content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
342 if !ok {
343 continue
344 }
345 response := map[string]any{"result": content.Error.Error()}
346 parts = append(parts, &genai.Part{
347 FunctionResponse: &genai.FunctionResponse{
348 ID: result.ToolCallID,
349 Response: response,
350 Name: toolCall.ToolName,
351 },
352 })
353
354 }
355 }
356 }
357 if len(parts) > 0 {
358 content = append(content, &genai.Content{
359 Role: genai.RoleUser,
360 Parts: parts,
361 })
362 }
363 }
364 }
365 return systemInstructions, content, warnings
366}
367
368// Generate implements ai.LanguageModel.
369func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
370 // params, err := g.prepareParams(call)
371 // if err != nil {
372 // return nil, err
373 // }
374 return nil, errors.New("unimplemented")
375}
376
// Model implements ai.LanguageModel. It reports the model identifier this
// instance was created with.
func (g *languageModel) Model() string {
	return g.modelID
}
381
// Provider implements ai.LanguageModel. It reports the provider label in the
// form "<name>.generative-ai".
func (g *languageModel) Provider() string {
	return g.provider
}
386
// Stream implements ai.LanguageModel.
//
// TODO: streaming is not yet implemented; this always returns an error.
func (g *languageModel) Stream(context.Context, ai.Call) (ai.StreamResponse, error) {
	return nil, errors.New("unimplemented")
}
391
392func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
393 for _, tool := range tools {
394 if tool.GetType() == ai.ToolTypeFunction {
395 ft, ok := tool.(ai.FunctionTool)
396 if !ok {
397 continue
398 }
399
400 required := []string{}
401 var properties map[string]any
402 if props, ok := ft.InputSchema["properties"]; ok {
403 properties, _ = props.(map[string]any)
404 }
405 if req, ok := ft.InputSchema["required"]; ok {
406 if reqArr, ok := req.([]string); ok {
407 required = reqArr
408 }
409 }
410 declaration := &genai.FunctionDeclaration{
411 Name: ft.Name,
412 Description: ft.Description,
413 Parameters: &genai.Schema{
414 Type: genai.TypeObject,
415 Properties: convertSchemaProperties(properties),
416 Required: required,
417 },
418 }
419 googleTools = append(googleTools, declaration)
420 continue
421 }
422 // TODO: handle provider tool calls
423 warnings = append(warnings, ai.CallWarning{
424 Type: ai.CallWarningTypeUnsupportedTool,
425 Tool: tool,
426 Message: "tool is not supported",
427 })
428 }
429 if toolChoice == nil {
430 return
431 }
432 switch *toolChoice {
433 case ai.ToolChoiceAuto:
434 googleToolChoice = &genai.ToolConfig{
435 FunctionCallingConfig: &genai.FunctionCallingConfig{
436 Mode: genai.FunctionCallingConfigModeAuto,
437 },
438 }
439 case ai.ToolChoiceRequired:
440 googleToolChoice = &genai.ToolConfig{
441 FunctionCallingConfig: &genai.FunctionCallingConfig{
442 Mode: genai.FunctionCallingConfigModeAny,
443 },
444 }
445 case ai.ToolChoiceNone:
446 googleToolChoice = &genai.ToolConfig{
447 FunctionCallingConfig: &genai.FunctionCallingConfig{
448 Mode: genai.FunctionCallingConfigModeNone,
449 },
450 }
451 default:
452 googleToolChoice = &genai.ToolConfig{
453 FunctionCallingConfig: &genai.FunctionCallingConfig{
454 Mode: genai.FunctionCallingConfigModeAny,
455 AllowedFunctionNames: []string{
456 string(*toolChoice),
457 },
458 },
459 }
460 }
461 return
462}
463
464func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
465 properties := make(map[string]*genai.Schema)
466
467 for name, param := range parameters {
468 properties[name] = convertToSchema(param)
469 }
470
471 return properties
472}
473
474func convertToSchema(param any) *genai.Schema {
475 schema := &genai.Schema{Type: genai.TypeString}
476
477 paramMap, ok := param.(map[string]any)
478 if !ok {
479 return schema
480 }
481
482 if desc, ok := paramMap["description"].(string); ok {
483 schema.Description = desc
484 }
485
486 typeVal, hasType := paramMap["type"]
487 if !hasType {
488 return schema
489 }
490
491 typeStr, ok := typeVal.(string)
492 if !ok {
493 return schema
494 }
495
496 schema.Type = mapJSONTypeToGoogle(typeStr)
497
498 switch typeStr {
499 case "array":
500 schema.Items = processArrayItems(paramMap)
501 case "object":
502 if props, ok := paramMap["properties"].(map[string]any); ok {
503 schema.Properties = convertSchemaProperties(props)
504 }
505 }
506
507 return schema
508}
509
510func processArrayItems(paramMap map[string]any) *genai.Schema {
511 items, ok := paramMap["items"].(map[string]any)
512 if !ok {
513 return nil
514 }
515
516 return convertToSchema(items)
517}
518
519func mapJSONTypeToGoogle(jsonType string) genai.Type {
520 switch jsonType {
521 case "string":
522 return genai.TypeString
523 case "number":
524 return genai.TypeNumber
525 case "integer":
526 return genai.TypeInteger
527 case "boolean":
528 return genai.TypeBoolean
529 case "array":
530 return genai.TypeArray
531 case "object":
532 return genai.TypeObject
533 default:
534 return genai.TypeString // Default to string for unknown types
535 }
536}