google.go

  1package google
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/base64"
  7	"encoding/json"
  8	"errors"
  9	"fmt"
 10	"maps"
 11	"net/http"
 12	"strings"
 13
 14	"github.com/charmbracelet/fantasy/ai"
 15	"github.com/charmbracelet/x/exp/slice"
 16	"github.com/google/uuid"
 17	"google.golang.org/genai"
 18)
 19
// provider implements ai.Provider backed by Google's Gemini API.
type provider struct {
	options options
}
 23
// options holds the provider configuration assembled by New from Option values.
type options struct {
	apiKey  string            // Gemini API key used to authenticate requests
	name    string            // provider name; defaults to "google" in New
	headers map[string]string // extra HTTP headers; NOTE(review): collected but never applied in this file — confirm intended wiring
	client  *http.Client      // optional custom HTTP client handed to the genai SDK
}
 30
 31type Option = func(*options)
 32
 33func New(opts ...Option) ai.Provider {
 34	options := options{
 35		headers: map[string]string{},
 36	}
 37	for _, o := range opts {
 38		o(&options)
 39	}
 40
 41	options.name = cmp.Or(options.name, "google")
 42
 43	return &provider{
 44		options: options,
 45	}
 46}
 47
 48func WithAPIKey(apiKey string) Option {
 49	return func(o *options) {
 50		o.apiKey = apiKey
 51	}
 52}
 53
 54func WithName(name string) Option {
 55	return func(o *options) {
 56		o.name = name
 57	}
 58}
 59
 60func WithHeaders(headers map[string]string) Option {
 61	return func(o *options) {
 62		maps.Copy(o.headers, headers)
 63	}
 64}
 65
 66func WithHTTPClient(client *http.Client) Option {
 67	return func(o *options) {
 68		o.client = client
 69	}
 70}
 71
// languageModel implements ai.LanguageModel for a single Gemini model.
type languageModel struct {
	provider        string        // provider identifier, e.g. "google.generative-ai"
	modelID         string        // Gemini model name passed to the genai SDK
	client          *genai.Client // configured genai SDK client
	providerOptions options       // options inherited from the parent provider
}
 78
 79// LanguageModel implements ai.Provider.
 80func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
 81	cc := &genai.ClientConfig{
 82		APIKey:     g.options.apiKey,
 83		Backend:    genai.BackendGeminiAPI,
 84		HTTPClient: g.options.client,
 85	}
 86	client, err := genai.NewClient(context.Background(), cc)
 87	if err != nil {
 88		return nil, err
 89	}
 90	return &languageModel{
 91		modelID:         modelID,
 92		provider:        fmt.Sprintf("%s.generative-ai", g.options.name),
 93		providerOptions: g.options,
 94		client:          client,
 95	}, nil
 96}
 97
 98func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
 99	config := &genai.GenerateContentConfig{}
100	providerOptions := &providerOptions{}
101	if v, ok := call.ProviderOptions["google"]; ok {
102		err := ai.ParseOptions(v, providerOptions)
103		if err != nil {
104			return nil, nil, nil, err
105		}
106	}
107
108	systemInstructions, content, warnings := toGooglePrompt(call.Prompt)
109
110	if providerOptions.ThinkingConfig != nil &&
111		providerOptions.ThinkingConfig.IncludeThoughts != nil &&
112		*providerOptions.ThinkingConfig.IncludeThoughts &&
113		strings.HasPrefix(a.provider, "google.vertex.") {
114		warnings = append(warnings, ai.CallWarning{
115			Type: ai.CallWarningTypeOther,
116			Message: "The 'includeThoughts' option is only supported with the Google Vertex provider " +
117				"and might not be supported or could behave unexpectedly with the current Google provider " +
118				fmt.Sprintf("(%s)", a.provider),
119		})
120	}
121
122	isGemmaModel := strings.HasPrefix(strings.ToLower(a.modelID), "gemma-")
123
124	if isGemmaModel && systemInstructions != nil && len(systemInstructions.Parts) > 0 {
125		if len(content) > 0 && content[0].Role == genai.RoleUser {
126			systemParts := []string{}
127			for _, sp := range systemInstructions.Parts {
128				systemParts = append(systemParts, sp.Text)
129			}
130			systemMsg := strings.Join(systemParts, "\n")
131			content[0].Parts = append([]*genai.Part{
132				{
133					Text: systemMsg + "\n\n",
134				},
135			}, content[0].Parts...)
136			systemInstructions = nil
137		}
138	}
139
140	config.SystemInstruction = systemInstructions
141
142	if call.MaxOutputTokens != nil {
143		config.MaxOutputTokens = int32(*call.MaxOutputTokens) //nolint: gosec
144	}
145
146	if call.Temperature != nil {
147		tmp := float32(*call.Temperature)
148		config.Temperature = &tmp
149	}
150	if call.TopK != nil {
151		tmp := float32(*call.TopK)
152		config.TopK = &tmp
153	}
154	if call.TopP != nil {
155		tmp := float32(*call.TopP)
156		config.TopP = &tmp
157	}
158	if call.FrequencyPenalty != nil {
159		tmp := float32(*call.FrequencyPenalty)
160		config.FrequencyPenalty = &tmp
161	}
162	if call.PresencePenalty != nil {
163		tmp := float32(*call.PresencePenalty)
164		config.PresencePenalty = &tmp
165	}
166
167	if providerOptions.ThinkingConfig != nil {
168		config.ThinkingConfig = &genai.ThinkingConfig{}
169		if providerOptions.ThinkingConfig.IncludeThoughts != nil {
170			config.ThinkingConfig.IncludeThoughts = *providerOptions.ThinkingConfig.IncludeThoughts
171		}
172		if providerOptions.ThinkingConfig.ThinkingBudget != nil {
173			tmp := int32(*providerOptions.ThinkingConfig.ThinkingBudget) //nolint: gosec
174			config.ThinkingConfig.ThinkingBudget = &tmp
175		}
176	}
177	for _, safetySetting := range providerOptions.SafetySettings {
178		config.SafetySettings = append(config.SafetySettings, &genai.SafetySetting{
179			Category:  genai.HarmCategory(safetySetting.Category),
180			Threshold: genai.HarmBlockThreshold(safetySetting.Threshold),
181		})
182	}
183	if providerOptions.CachedContent != "" {
184		config.CachedContent = providerOptions.CachedContent
185	}
186
187	if len(call.Tools) > 0 {
188		tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
189		config.ToolConfig = toolChoice
190		config.Tools = append(config.Tools, &genai.Tool{
191			FunctionDeclarations: tools,
192		})
193		warnings = append(warnings, toolWarnings...)
194	}
195
196	return config, content, warnings, nil
197}
198
199func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.CallWarning) { //nolint: unparam
200	var systemInstructions *genai.Content
201	var content []*genai.Content
202	var warnings []ai.CallWarning
203
204	finishedSystemBlock := false
205	for _, msg := range prompt {
206		switch msg.Role {
207		case ai.MessageRoleSystem:
208			if finishedSystemBlock {
209				// skip multiple system messages that are separated by user/assistant messages
210				// TODO: see if we need to send error here?
211				continue
212			}
213			finishedSystemBlock = true
214
215			var systemMessages []string
216			for _, part := range msg.Content {
217				text, ok := ai.AsMessagePart[ai.TextPart](part)
218				if !ok || text.Text == "" {
219					continue
220				}
221				systemMessages = append(systemMessages, text.Text)
222			}
223			if len(systemMessages) > 0 {
224				systemInstructions = &genai.Content{
225					Parts: []*genai.Part{
226						{
227							Text: strings.Join(systemMessages, "\n"),
228						},
229					},
230				}
231			}
232		case ai.MessageRoleUser:
233			var parts []*genai.Part
234			for _, part := range msg.Content {
235				switch part.GetType() {
236				case ai.ContentTypeText:
237					text, ok := ai.AsMessagePart[ai.TextPart](part)
238					if !ok || text.Text == "" {
239						continue
240					}
241					parts = append(parts, &genai.Part{
242						Text: text.Text,
243					})
244				case ai.ContentTypeFile:
245					file, ok := ai.AsMessagePart[ai.FilePart](part)
246					if !ok {
247						continue
248					}
249					var encoded []byte
250					base64.StdEncoding.Encode(encoded, file.Data)
251					parts = append(parts, &genai.Part{
252						InlineData: &genai.Blob{
253							Data:     encoded,
254							MIMEType: file.MediaType,
255						},
256					})
257				}
258			}
259			if len(parts) > 0 {
260				content = append(content, &genai.Content{
261					Role:  genai.RoleUser,
262					Parts: parts,
263				})
264			}
265		case ai.MessageRoleAssistant:
266			var parts []*genai.Part
267			for _, part := range msg.Content {
268				switch part.GetType() {
269				case ai.ContentTypeText:
270					text, ok := ai.AsMessagePart[ai.TextPart](part)
271					if !ok || text.Text == "" {
272						continue
273					}
274					parts = append(parts, &genai.Part{
275						Text: text.Text,
276					})
277				case ai.ContentTypeToolCall:
278					toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
279					if !ok {
280						continue
281					}
282
283					var result map[string]any
284					err := json.Unmarshal([]byte(toolCall.Input), &result)
285					if err != nil {
286						continue
287					}
288					parts = append(parts, &genai.Part{
289						FunctionCall: &genai.FunctionCall{
290							ID:   toolCall.ToolCallID,
291							Name: toolCall.ToolName,
292							Args: result,
293						},
294					})
295				}
296			}
297			if len(parts) > 0 {
298				content = append(content, &genai.Content{
299					Role:  genai.RoleModel,
300					Parts: parts,
301				})
302			}
303		case ai.MessageRoleTool:
304			var parts []*genai.Part
305			for _, part := range msg.Content {
306				switch part.GetType() {
307				case ai.ContentTypeToolResult:
308					result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
309					if !ok {
310						continue
311					}
312					var toolCall ai.ToolCallPart
313					for _, m := range prompt {
314						if m.Role == ai.MessageRoleAssistant {
315							for _, content := range m.Content {
316								tc, ok := ai.AsMessagePart[ai.ToolCallPart](content)
317								if !ok {
318									continue
319								}
320								if tc.ToolCallID == result.ToolCallID {
321									toolCall = tc
322									break
323								}
324							}
325						}
326					}
327					switch result.Output.GetType() {
328					case ai.ToolResultContentTypeText:
329						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
330						if !ok {
331							continue
332						}
333						response := map[string]any{"result": content.Text}
334						parts = append(parts, &genai.Part{
335							FunctionResponse: &genai.FunctionResponse{
336								ID:       result.ToolCallID,
337								Response: response,
338								Name:     toolCall.ToolName,
339							},
340						})
341
342					case ai.ToolResultContentTypeError:
343						content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
344						if !ok {
345							continue
346						}
347						response := map[string]any{"result": content.Error.Error()}
348						parts = append(parts, &genai.Part{
349							FunctionResponse: &genai.FunctionResponse{
350								ID:       result.ToolCallID,
351								Response: response,
352								Name:     toolCall.ToolName,
353							},
354						})
355					}
356				}
357			}
358			if len(parts) > 0 {
359				content = append(content, &genai.Content{
360					Role:  genai.RoleUser,
361					Parts: parts,
362				})
363			}
364		}
365	}
366	return systemInstructions, content, warnings
367}
368
369// Generate implements ai.LanguageModel.
370func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
371	config, contents, warnings, err := g.prepareParams(call)
372	if err != nil {
373		return nil, err
374	}
375
376	lastMessage, history, ok := slice.Pop(contents)
377	if !ok {
378		return nil, errors.New("no messages to send")
379	}
380
381	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
382	if err != nil {
383		return nil, err
384	}
385
386	response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
387	if err != nil {
388		return nil, err
389	}
390
391	return mapResponse(response, warnings)
392}
393
394// Model implements ai.LanguageModel.
395func (g *languageModel) Model() string {
396	return g.modelID
397}
398
399// Provider implements ai.LanguageModel.
400func (g *languageModel) Provider() string {
401	return g.provider
402}
403
// Stream implements ai.LanguageModel. It returns an iterator that yields
// ai.StreamPart values: any warnings first, then text deltas and tool calls
// as the model produces them, and finally a finish part carrying usage.
// Stopping the iterator early (yield returning false) aborts the stream.
func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	config, contents, warnings, err := g.prepareParams(call)
	if err != nil {
		return nil, err
	}

	// The genai chat API takes prior history separately from the message
	// being sent; split off the final content entry.
	lastMessage, history, ok := slice.Pop(contents)
	if !ok {
		return nil, errors.New("no messages to send")
	}

	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
	if err != nil {
		return nil, err
	}

	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		// NOTE(review): currentContent is accumulated but never read —
		// candidate for removal.
		var currentContent string
		var toolCalls []ai.ToolCallContent
		var isActiveText bool // true while a text segment (ID "0") is open
		var usage ai.Usage

		// Stream the response
		for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
			if err != nil {
				yield(ai.StreamPart{
					Type:  ai.StreamPartTypeError,
					Error: err,
				})
				return
			}

			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
				for _, part := range resp.Candidates[0].Content.Parts {
					switch {
					case part.Text != "":
						delta := part.Text
						if delta != "" {
							// Open the text segment lazily on the first delta.
							if !isActiveText {
								isActiveText = true
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeTextStart,
									ID:   "0",
								}) {
									return
								}
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeTextDelta,
								ID:    "0",
								Delta: delta,
							}) {
								return
							}
							currentContent += delta
						}
					case part.FunctionCall != nil:
						// A tool call ends any open text segment.
						if isActiveText {
							isActiveText = false
							if !yield(ai.StreamPart{
								Type: ai.StreamPartTypeTextEnd,
								ID:   "0",
							}) {
								return
							}
						}

						// Prefer the SDK-provided call ID, then the function
						// name, then a generated UUID.
						toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())

						args, err := json.Marshal(part.FunctionCall.Args)
						if err != nil {
							yield(ai.StreamPart{
								Type:  ai.StreamPartTypeError,
								Error: err,
							})
							return
						}

						// Gemini delivers the full arguments at once, so the
						// input start/delta/end sequence is emitted together.
						if !yield(ai.StreamPart{
							Type:         ai.StreamPartTypeToolInputStart,
							ID:           toolCallID,
							ToolCallName: part.FunctionCall.Name,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:  ai.StreamPartTypeToolInputDelta,
							ID:    toolCallID,
							Delta: string(args),
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeToolInputEnd,
							ID:   toolCallID,
						}) {
							return
						}

						if !yield(ai.StreamPart{
							Type:             ai.StreamPartTypeToolCall,
							ID:               toolCallID,
							ToolCallName:     part.FunctionCall.Name,
							ToolCallInput:    string(args),
							ProviderExecuted: false,
						}) {
							return
						}

						toolCalls = append(toolCalls, ai.ToolCallContent{
							ToolCallID:       toolCallID,
							ToolName:         part.FunctionCall.Name,
							Input:            string(args),
							ProviderExecuted: false,
						})
					}
				}
			}

			// Usage metadata arrives on (at least) the final chunk; keep the
			// latest value seen.
			if resp.UsageMetadata != nil {
				usage = mapUsage(resp.UsageMetadata)
			}
		}

		// Close any text segment left open when the stream ended.
		if isActiveText {
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID:   "0",
			}) {
				return
			}
		}

		finishReason := ai.FinishReasonStop
		if len(toolCalls) > 0 {
			finishReason = ai.FinishReasonToolCalls
		}

		yield(ai.StreamPart{
			Type:         ai.StreamPartTypeFinish,
			Usage:        usage,
			FinishReason: finishReason,
		})
	}, nil
}
561
562func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
563	for _, tool := range tools {
564		if tool.GetType() == ai.ToolTypeFunction {
565			ft, ok := tool.(ai.FunctionTool)
566			if !ok {
567				continue
568			}
569
570			required := []string{}
571			var properties map[string]any
572			if props, ok := ft.InputSchema["properties"]; ok {
573				properties, _ = props.(map[string]any)
574			}
575			if req, ok := ft.InputSchema["required"]; ok {
576				if reqArr, ok := req.([]string); ok {
577					required = reqArr
578				}
579			}
580			declaration := &genai.FunctionDeclaration{
581				Name:        ft.Name,
582				Description: ft.Description,
583				Parameters: &genai.Schema{
584					Type:       genai.TypeObject,
585					Properties: convertSchemaProperties(properties),
586					Required:   required,
587				},
588			}
589			googleTools = append(googleTools, declaration)
590			continue
591		}
592		// TODO: handle provider tool calls
593		warnings = append(warnings, ai.CallWarning{
594			Type:    ai.CallWarningTypeUnsupportedTool,
595			Tool:    tool,
596			Message: "tool is not supported",
597		})
598	}
599	if toolChoice == nil {
600		return //nolint: nakedret
601	}
602	switch *toolChoice {
603	case ai.ToolChoiceAuto:
604		googleToolChoice = &genai.ToolConfig{
605			FunctionCallingConfig: &genai.FunctionCallingConfig{
606				Mode: genai.FunctionCallingConfigModeAuto,
607			},
608		}
609	case ai.ToolChoiceRequired:
610		googleToolChoice = &genai.ToolConfig{
611			FunctionCallingConfig: &genai.FunctionCallingConfig{
612				Mode: genai.FunctionCallingConfigModeAny,
613			},
614		}
615	case ai.ToolChoiceNone:
616		googleToolChoice = &genai.ToolConfig{
617			FunctionCallingConfig: &genai.FunctionCallingConfig{
618				Mode: genai.FunctionCallingConfigModeNone,
619			},
620		}
621	default:
622		googleToolChoice = &genai.ToolConfig{
623			FunctionCallingConfig: &genai.FunctionCallingConfig{
624				Mode: genai.FunctionCallingConfigModeAny,
625				AllowedFunctionNames: []string{
626					string(*toolChoice),
627				},
628			},
629		}
630	}
631	return //nolint: nakedret
632}
633
634func convertSchemaProperties(parameters map[string]any) map[string]*genai.Schema {
635	properties := make(map[string]*genai.Schema)
636
637	for name, param := range parameters {
638		properties[name] = convertToSchema(param)
639	}
640
641	return properties
642}
643
644func convertToSchema(param any) *genai.Schema {
645	schema := &genai.Schema{Type: genai.TypeString}
646
647	paramMap, ok := param.(map[string]any)
648	if !ok {
649		return schema
650	}
651
652	if desc, ok := paramMap["description"].(string); ok {
653		schema.Description = desc
654	}
655
656	typeVal, hasType := paramMap["type"]
657	if !hasType {
658		return schema
659	}
660
661	typeStr, ok := typeVal.(string)
662	if !ok {
663		return schema
664	}
665
666	schema.Type = mapJSONTypeToGoogle(typeStr)
667
668	switch typeStr {
669	case "array":
670		schema.Items = processArrayItems(paramMap)
671	case "object":
672		if props, ok := paramMap["properties"].(map[string]any); ok {
673			schema.Properties = convertSchemaProperties(props)
674		}
675	}
676
677	return schema
678}
679
680func processArrayItems(paramMap map[string]any) *genai.Schema {
681	items, ok := paramMap["items"].(map[string]any)
682	if !ok {
683		return nil
684	}
685
686	return convertToSchema(items)
687}
688
689func mapJSONTypeToGoogle(jsonType string) genai.Type {
690	switch jsonType {
691	case "string":
692		return genai.TypeString
693	case "number":
694		return genai.TypeNumber
695	case "integer":
696		return genai.TypeInteger
697	case "boolean":
698		return genai.TypeBoolean
699	case "array":
700		return genai.TypeArray
701	case "object":
702		return genai.TypeObject
703	default:
704		return genai.TypeString // Default to string for unknown types
705	}
706}
707
708func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
709	if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
710		return nil, errors.New("no response from model")
711	}
712
713	var (
714		content      []ai.Content
715		finishReason ai.FinishReason
716		hasToolCalls bool
717		candidate    = response.Candidates[0]
718	)
719
720	for _, part := range candidate.Content.Parts {
721		switch {
722		case part.Text != "":
723			content = append(content, ai.TextContent{Text: part.Text})
724		case part.FunctionCall != nil:
725			input, err := json.Marshal(part.FunctionCall.Args)
726			if err != nil {
727				return nil, err
728			}
729			content = append(content, ai.ToolCallContent{
730				ToolCallID:       part.FunctionCall.ID,
731				ToolName:         part.FunctionCall.Name,
732				Input:            string(input),
733				ProviderExecuted: false,
734			})
735			hasToolCalls = true
736		default:
737			return nil, fmt.Errorf("not implemented part type")
738		}
739	}
740
741	if hasToolCalls {
742		finishReason = ai.FinishReasonToolCalls
743	} else {
744		finishReason = mapFinishReason(candidate.FinishReason)
745	}
746
747	return &ai.Response{
748		Content:      content,
749		Usage:        mapUsage(response.UsageMetadata),
750		FinishReason: finishReason,
751		Warnings:     warnings,
752	}, nil
753}
754
755func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
756	switch reason {
757	case genai.FinishReasonStop:
758		return ai.FinishReasonStop
759	case genai.FinishReasonMaxTokens:
760		return ai.FinishReasonLength
761	case genai.FinishReasonSafety,
762		genai.FinishReasonBlocklist,
763		genai.FinishReasonProhibitedContent,
764		genai.FinishReasonSPII,
765		genai.FinishReasonImageSafety:
766		return ai.FinishReasonContentFilter
767	case genai.FinishReasonRecitation,
768		genai.FinishReasonLanguage,
769		genai.FinishReasonMalformedFunctionCall:
770		return ai.FinishReasonError
771	case genai.FinishReasonOther:
772		return ai.FinishReasonOther
773	default:
774		return ai.FinishReasonUnknown
775	}
776}
777
778func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
779	return ai.Usage{
780		InputTokens:         int64(usage.ToolUsePromptTokenCount),
781		OutputTokens:        int64(usage.CandidatesTokenCount),
782		TotalTokens:         int64(usage.TotalTokenCount),
783		ReasoningTokens:     int64(usage.ThoughtsTokenCount),
784		CacheCreationTokens: int64(usage.CachedContentTokenCount),
785		CacheReadTokens:     0,
786	}
787}