google.go

  1package google
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/base64"
  7	"encoding/json"
  8	"errors"
  9	"fmt"
 10	"maps"
 11	"net/http"
 12	"strings"
 13
 14	"github.com/charmbracelet/fantasy/ai"
 15	"github.com/charmbracelet/x/exp/slice"
 16	"github.com/google/uuid"
 17	"google.golang.org/genai"
 18)
 19
// provider is the Google implementation of ai.Provider.
type provider struct {
	options options
}

// options holds provider-level configuration collected from Option funcs.
type options struct {
	apiKey  string            // Gemini API key used to authenticate
	name    string            // provider name; defaults to "google" in New
	headers map[string]string // extra HTTP headers merged in via WithHeaders
	client  *http.Client      // optional custom HTTP client for the genai SDK
}

// Option mutates the provider options; pass instances to New.
type Option = func(*options)
 32
 33func New(opts ...Option) ai.Provider {
 34	options := options{
 35		headers: map[string]string{},
 36	}
 37	for _, o := range opts {
 38		o(&options)
 39	}
 40
 41	if options.name == "" {
 42		options.name = "google"
 43	}
 44
 45	return &provider{
 46		options: options,
 47	}
 48}
 49
 50func WithAPIKey(apiKey string) Option {
 51	return func(o *options) {
 52		o.apiKey = apiKey
 53	}
 54}
 55
 56func WithName(name string) Option {
 57	return func(o *options) {
 58		o.name = name
 59	}
 60}
 61
 62func WithHeaders(headers map[string]string) Option {
 63	return func(o *options) {
 64		maps.Copy(o.headers, headers)
 65	}
 66}
 67
 68func WithHTTPClient(client *http.Client) Option {
 69	return func(o *options) {
 70		o.client = client
 71	}
 72}
 73
// languageModel is the Google implementation of ai.LanguageModel backed by
// a genai.Client.
type languageModel struct {
	provider        string        // provider identifier, e.g. "google.generative-ai"
	modelID         string        // Gemini model ID requested by the caller
	client          *genai.Client // underlying SDK client
	providerOptions options       // copy of the provider-level options
}
 80
 81// LanguageModel implements ai.Provider.
 82func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
 83	cc := &genai.ClientConfig{
 84		APIKey:     g.options.apiKey,
 85		Backend:    genai.BackendGeminiAPI,
 86		HTTPClient: g.options.client,
 87	}
 88	client, err := genai.NewClient(context.Background(), cc)
 89	if err != nil {
 90		return nil, err
 91	}
 92	return &languageModel{
 93		modelID:         modelID,
 94		provider:        fmt.Sprintf("%s.generative-ai", g.options.name),
 95		providerOptions: g.options,
 96		client:          client,
 97	}, nil
 98}
 99
100func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
101	config := &genai.GenerateContentConfig{}
102	providerOptions := &providerOptions{}
103	if v, ok := call.ProviderOptions["google"]; ok {
104		err := ai.ParseOptions(v, providerOptions)
105		if err != nil {
106			return nil, nil, nil, err
107		}
108	}
109
110	systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
111
112	if providerOptions.ThinkingConfig != nil &&
113		providerOptions.ThinkingConfig.IncludeThoughts != nil &&
114		*providerOptions.ThinkingConfig.IncludeThoughts &&
115		strings.HasPrefix(a.provider, "google.vertex.") {
116		warnings = append(warnings, ai.CallWarning{
117			Type: ai.CallWarningTypeOther,
118			Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
119				"and might not be supported or could behave unexpectedly with the current Google provider " +
120				fmt.Sprintf("(%s)", a.provider),
121		})
122	}
123
124	isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")
125
126	if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
127		if len(content) > 0 && content[0].Role == genai.RoleUser {
128			systemParts := []string{}
129			for _, sp := range systemInstructions.Parts {
130				systemParts = append(systemParts, sp.Text)
131			}
132			systemMsg := strings.Join(systemParts, "\n")
133			content[0].Parts = append([]*genai.Part{
134				{
135					Text: systemMsg + "\n\n",
136				},
137			}, content[0].Parts...)
138			systemInstructions = nil
139		}
140	}
141
142	config.SystemInstruction = systemInstructions
143
144	if call.MaxOutputTokens != nil {
145		config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
146	}
147
148	if call.Temperature != nil {
149		tmp := float32(*call.Temperature)
150		config.Temperature = &tmp
151	}
152	if call.TopK != nil {
153		tmp := float32(*call.TopK)
154		config.TopK = &tmp
155	}
156	if call.TopP != nil {
157		tmp := float32(*call.TopP)
158		config.TopP = &tmp
159	}
160	if call.FrequencyPenalty != nil {
161		tmp := float32(*call.FrequencyPenalty)
162		config.FrequencyPenalty = &tmp
163	}
164	if call.PresencePenalty != nil {
165		tmp := float32(*call.PresencePenalty)
166		config.PresencePenalty = &tmp
167	}
168
169	if providerOptions.ThinkingConfig != nil {
170		config.ThinkingConfig = &genai.ThinkingConfig{}
171		if providerOptions.ThinkingConfig.IncludeThoughts != nil {
172			config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
173		}
174		if providerOptions.ThinkingConfig.ThinkingBudget != nil {
175			tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
176			config.ThinkingConfig.ThinkingBudget = &tmp
177		}
178	}
179	for _, safetySetting := range providerOptions.SafetySettings {
180		config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
181			Category:  genai.HarmCategory(safetySetting.Category),
182			Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
183		})
184	}
185	if providerOptions.CachedContent != "" {
186		config.CachedContent = providerOptions.CachedContent
187	}
188
189	if len(call.Tools) > 0 {
190		tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
191		config.ToolConfig = toolChoice
192		config.Tools = append(config.Tools, &genai.Tool{
193			FunctionDeclarations: tools,
194		})
195		warnings = append(warnings, toolWarnings...)
196	}
197
198	return config, content, warnings, nil
199}
200
201func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) { //nolint: unparam
202	var systemInstructions *genai.Content
203	var content []*genai.Content
204	var warnings []ai.CallWarning
205
206	finishedSystemBlock := false
207	for _, msg := range prompt {
208		switch msg.Role {
209		case ai.MessageRoleSystem:
210			if finishedSystemBlock {
211				// skip multiple system messages that are separated by user/assistant messages
212				// TODO: see if we need to send error here?
213				continue
214			}
215			finishedSystemBlock = true
216
217			var systemMessages []string
218			for _, part := range msg.Content {
219				text, ok := ai.AsMessagePart[ai.TextPart](part)
220				if !ok || text.Text == "" {
221					continue
222				}
223				systemMessages = append(systemMessages, text.Text)
224			}
225			if len(systemMessages) > 0 {
226				systemInstructions = &genai.Content{
227					Parts: []*genai.Part{
228						{
229							Text: strings.Join(systemMessages, "\n"),
230						},
231					},
232				}
233			}
234		case ai.MessageRoleUser:
235			var parts []*genai.Part
236			for _, part := range msg.Content {
237				switch part.GetType() {
238				case ai.ContentTypeText:
239					text, ok := ai.AsMessagePart[ai.TextPart](part)
240					if !ok || text.Text == "" {
241						continue
242					}
243					parts = append(parts, &genai.Part{
244						Text: text.Text,
245					})
246				case ai.ContentTypeFile:
247					file, ok := ai.AsMessagePart[ai.FilePart](part)
248					if !ok {
249						continue
250					}
251					var encoded []byte
252					base64.StdEncoding.Encode(encoded, file.Data)
253					parts = append(parts, &genai.Part{
254						InlineData: &genai.Blob{
255							Data:     encoded,
256							MIMEType: file.MediaType,
257						},
258					})
259				}
260			}
261			if len(parts) > 0 {
262				content = append(content, &genai.Content{
263					Role:  genai.RoleUser,
264					Parts: parts,
265				})
266			}
267		case ai.MessageRoleAssistant:
268			var parts []*genai.Part
269			for _, part := range msg.Content {
270				switch part.GetType() {
271				case ai.ContentTypeText:
272					text, ok := ai.AsMessagePart[ai.TextPart](part)
273					if !ok || text.Text == "" {
274						continue
275					}
276					parts = append(parts, &genai.Part{
277						Text: text.Text,
278					})
279				case ai.ContentTypeToolCall:
280					toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
281					if !ok {
282						continue
283					}
284
285					var result map[string]any
286					err := json.Unmarshal([]byte(toolCall.Input), &result)
287					if err != nil {
288						continue
289					}
290					parts = append(parts, &genai.Part{
291						FunctionCall: &genai.FunctionCall{
292							ID:   toolCall.ToolCallID,
293							Name: toolCall.ToolName,
294							Args: result,
295						},
296					})
297				}
298			}
299			if len(parts) > 0 {
300				content = append(content, &genai.Content{
301					Role:  genai.RoleModel,
302					Parts: parts,
303				})
304			}
305		case ai.MessageRoleTool:
306			var parts []*genai.Part
307			for _, part := range msg.Content {
308				switch part.GetType() {
309				case ai.ContentTypeToolResult:
310					result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
311					if !ok {
312						continue
313					}
314					var toolCall ai.ToolCallPart
315					for _, m := range prompt {
316						if m.Role == ai.MessageRoleAssistant {
317							for _, content := range m.Content {
318								tc, ok := ai.AsMessagePart[ai.ToolCallPart](content)
319								if !ok {
320									continue
321								}
322								if tc.ToolCallID == result.ToolCallID {
323									toolCall = tc
324									break
325								}
326							}
327						}
328					}
329					switch result.Output.GetType() {
330					case ai.ToolResultContentTypeText:
331						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
332						if !ok {
333							continue
334						}
335						response := map[string]any{"result": content.Text}
336						parts = append(parts, &genai.Part{
337							FunctionResponse: &genai.FunctionResponse{
338								ID:       result.ToolCallID,
339								Response: response,
340								Name:     toolCall.ToolName,
341							},
342						})
343
344					case ai.ToolResultContentTypeError:
345						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
346						if !ok {
347							continue
348						}
349						response := map[string]any{"result": content.Error.Error()}
350						parts = append(parts, &genai.Part{
351							FunctionResponse: &genai.FunctionResponse{
352								ID:       result.ToolCallID,
353								Response: response,
354								Name:     toolCall.ToolName,
355							},
356						})
357					}
358				}
359			}
360			if len(parts) > 0 {
361				content = append(content, &genai.Content{
362					Role:  genai.RoleUser,
363					Parts: parts,
364				})
365			}
366		}
367	}
368	return systemInstructions, content, warnings
369}
370
371// Generate implements ai.LanguageModel.
372func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
373	config, contents, warnings, err := g.prepareParams(call)
374	if err != nil {
375		return nil, err
376	}
377
378	lastMessage, history, ok := slice.Pop(contents)
379	if !ok {
380		return nil, errors.New("no messages to send")
381	}
382
383	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
384	if err != nil {
385		return nil, err
386	}
387
388	response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
389	if err != nil {
390		return nil, err
391	}
392
393	return mapResponse(response, warnings)
394}
395
396// Model implements ai.LanguageModel.
397func (g *languageModel) Model() string {
398	return g.modelID
399}
400
401// Provider implements ai.LanguageModel.
402func (g *languageModel) Provider() string {
403	return g.provider
404}
405
// Stream implements ai.LanguageModel. It returns an iterator that yields
// stream parts in protocol order: optional warnings first, then interleaved
// text deltas and tool calls, and a single finish part last. Text is
// streamed under the fixed block ID "0"; a text block is closed before any
// tool call is emitted.
func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}

	// The genai chat API takes prior turns as history and the newest
	// message separately.
	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}

	return func(yield func(ai.StreamPart) bool) {
		// Surface preparation warnings before any content.
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		var currentContent string
		var toolCalls []ai.ToolCallContent
		var isActiveText bool // true while the "0" text block is open
		var usage ai.Usage

		// Stream the response
		for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
			if err != nil {
				yield(ai.StreamPart{
					Type:  ai.StreamPartTypeError,
					Error: err,
				})
				return
			}

			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
				for _, part := range resp.Candidates[0].Content.Parts {
					switch {
					case part.Text != "":
						delta := part.Text
						if delta != "" {
							// Open the text block lazily on the first delta.
							if !isActiveText {
								isActiveText = true
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeTextStart,
									ID:   "0",
								}) {
									return
								}
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeTextDelta,
								ID:    "0",
								Delta: delta,
							}) {
								return
							}
							currentContent += delta
						}
					case part.FunctionCall != nil:
						// Close any open text block before tool-call parts.
						if isActiveText {
							isActiveText = false
							if !yield(ai.StreamPart{
								Type: ai.StreamPartTypeTextEnd,
								ID:   "0",
							}) {
								return
							}
						}

						// Prefer the SDK-provided call ID, fall back to the
						// function name, then a random UUID.
						toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())

						args, err := json.Marshal(part.FunctionCall.Args)
						if err != nil {
							yield(ai.StreamPart{
								Type:  ai.StreamPartTypeError,
								Error: err,
							})
							return
						}

						// genai delivers the call complete, so the
						// input-start/delta/end trio is emitted back-to-back
						// followed by the final tool-call part.
						if !yield(ai.StreamPart{
							Type:         ai.StreamPartTypeToolInputStart,
							ID:           toolCallID,
							ToolCallName: part.FunctionCall.Name,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:  ai.StreamPartTypeToolInputDelta,
							ID:    toolCallID,
							Delta: string(args),
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeToolInputEnd,
							ID:   toolCallID,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:             ai.StreamPartTypeToolCall,
							ID:               toolCallID,
							ToolCallName:     part.FunctionCall.Name,
							ToolCallInput:    string(args),
							ProviderExecuted: false,
						}) {
							return
						}

						toolCalls = append(toolCalls, ai.ToolCallContent{
							ToolCallID:       toolCallID,
							ToolName:         part.FunctionCall.Name,
							Input:            string(args),
							ProviderExecuted: false,
						})
					}
				}
			}

			// Usage metadata may arrive on any chunk; keep the latest.
			if resp.UsageMetadata != nil {
				usage = mapUsage(resp.UsageMetadata)
			}
		}

		// Close a text block left open by the final chunk.
		if isActiveText {
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID:   "0",
			}) {
				return
			}
		}

		finishReason := ai.FinishReasonStop
		if len(toolCalls) > 0 {
			finishReason = ai.FinishReasonToolCalls
		}

		yield(ai.StreamPart{
			Type:         ai.StreamPartTypeFinish,
			Usage:        usage,
			FinishReason: finishReason,
		})
	}, nil
}
563
564func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
565	for _, tool := range tools {
566		if tool.GetType() == ai.ToolTypeFunction {
567			ft, ok := tool.(ai.FunctionTool)
568			if !ok {
569				continue
570			}
571
572			required := []string{}
573			var properties map[string]any
574			if props, ok := ft.InputSchema["properties"]; ok {
575				properties, _ = props.(map[string]any)
576			}
577			if req, ok := ft.InputSchema["required"]; ok {
578				if reqArr, ok := req.([]string); ok {
579					required = reqArr
580				}
581			}
582			declaration := &genai.FunctionDeclaration{
583				Name:        ft.Name,
584				Description: ft.Description,
585				Parameters: &genai.Schema{
586					Type:       genai.TypeObject,
587					Properties: convertSchemaProperties(properties),
588					Required:   required,
589				},
590			}
591			googleTools = append(googleTools, declaration)
592			continue
593		}
594		// TODO: handle provider tool calls
595		warnings = append(warnings, ai.CallWarning{
596			Type:    ai.CallWarningTypeUnsupportedTool,
597			Tool:    tool,
598			Message: "tool is not supported",
599		})
600	}
601	if toolChoice == nil {
602		return //nolint: nakedret
603	}
604	switch *toolChoice {
605	case ai.ToolChoiceAuto:
606		googleToolChoice = &genai.ToolConfig{
607			FunctionCallingConfig: &genai.FunctionCallingConfig{
608				Mode: genai.FunctionCallingConfigModeAuto,
609			},
610		}
611	case ai.ToolChoiceRequired:
612		googleToolChoice = &genai.ToolConfig{
613			FunctionCallingConfig: &genai.FunctionCallingConfig{
614				Mode: genai.FunctionCallingConfigModeAny,
615			},
616		}
617	case ai.ToolChoiceNone:
618		googleToolChoice = &genai.ToolConfig{
619			FunctionCallingConfig: &genai.FunctionCallingConfig{
620				Mode: genai.FunctionCallingConfigModeNone,
621			},
622		}
623	default:
624		googleToolChoice = &genai.ToolConfig{
625			FunctionCallingConfig: &genai.FunctionCallingConfig{
626				Mode: genai.FunctionCallingConfigModeAny,
627				AllowedFunctionNames: []string{
628					string(*toolChoice),
629				},
630			},
631		}
632	}
633	return //nolint: nakedret
634}
635
636func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
637	properties := make(map[string]*genai.Schema)
638
639	for name, param := range parameters {
640		properties[name] = convertToSchema(param)
641	}
642
643	return properties
644}
645
646func convertToSchema(param any) *genai.Schema {
647	schema := &genai.Schema{Type: genai.TypeString}
648
649	paramMap, ok := param.(map[string]any)
650	if !ok {
651		return schema
652	}
653
654	if desc, ok := paramMap["description"].(string); ok {
655		schema.Description = desc
656	}
657
658	typeVal, hasType := paramMap["type"]
659	if !hasType {
660		return schema
661	}
662
663	typeStr, ok := typeVal.(string)
664	if !ok {
665		return schema
666	}
667
668	schema.Type = mapJSONTypeToGoogle(typeStr)
669
670	switch typeStr {
671	case "array":
672		schema.Items = processArrayItems(paramMap)
673	case "object":
674		if props, ok := paramMap["properties"].(map[string]any); ok {
675			schema.Properties = convertSchemaProperties(props)
676		}
677	}
678
679	return schema
680}
681
682func processArrayItems(paramMap map[string]any) *genai.Schema {
683	items, ok := paramMap["items"].(map[string]any)
684	if !ok {
685		return nil
686	}
687
688	return convertToSchema(items)
689}
690
691func mapJSONTypeToGoogle(jsonType string) genai.Type {
692	switch jsonType {
693	case "string":
694		return genai.TypeString
695	case "number":
696		return genai.TypeNumber
697	case "integer":
698		return genai.TypeInteger
699	case "boolean":
700		return genai.TypeBoolean
701	case "array":
702		return genai.TypeArray
703	case "object":
704		return genai.TypeObject
705	default:
706		return genai.TypeString // Default to string for unknown types
707	}
708}
709
710func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
711	if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
712		return nil, errors.New("no response from model")
713	}
714
715	var (
716		content      []ai.Content
717		finishReason ai.FinishReason
718		hasToolCalls bool
719		candidate    = response.Candidates[0]
720	)
721
722	for _, part := range candidate.Content.Parts {
723		switch {
724		case part.Text != "":
725			content = append(content, ai.TextContent{Text: part.Text})
726		case part.FunctionCall != nil:
727			input, err := json.Marshal(part.FunctionCall.Args)
728			if err != nil {
729				return nil, err
730			}
731			content = append(content, ai.ToolCallContent{
732				ToolCallID:       part.FunctionCall.ID,
733				ToolName:         part.FunctionCall.Name,
734				Input:            string(input),
735				ProviderExecuted: false,
736			})
737			hasToolCalls = true
738		default:
739			return nil, fmt.Errorf("not implemented part type")
740		}
741	}
742
743	if hasToolCalls {
744		finishReason = ai.FinishReasonToolCalls
745	} else {
746		finishReason = mapFinishReason(candidate.FinishReason)
747	}
748
749	return &ai.Response{
750		Content:      content,
751		Usage:        mapUsage(response.UsageMetadata),
752		FinishReason: finishReason,
753		Warnings:     warnings,
754	}, nil
755}
756
757func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
758	switch reason {
759	case genai.FinishReasonStop:
760		return ai.FinishReasonStop
761	case genai.FinishReasonMaxTokens:
762		return ai.FinishReasonLength
763	case genai.FinishReasonSafety,
764		genai.FinishReasonBlocklist,
765		genai.FinishReasonProhibitedContent,
766		genai.FinishReasonSPII,
767		genai.FinishReasonImageSafety:
768		return ai.FinishReasonContentFilter
769	case genai.FinishReasonRecitation,
770		genai.FinishReasonLanguage,
771		genai.FinishReasonMalformedFunctionCall:
772		return ai.FinishReasonError
773	case genai.FinishReasonOther:
774		return ai.FinishReasonOther
775	default:
776		return ai.FinishReasonUnknown
777	}
778}
779
780func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
781	return ai.Usage{
782		InputTokens:         int64(usage.ToolUsePromptTokenCount),
783		OutputTokens:        int64(usage.CandidatesTokenCount),
784		TotalTokens:         int64(usage.TotalTokenCount),
785		ReasoningTokens:     int64(usage.ThoughtsTokenCount),
786		CacheCreationTokens: int64(usage.CachedContentTokenCount),
787		CacheReadTokens:     0,
788	}
789}