feat: make google gemini work for basic + tool calls

Andrey Nering created

Change summary

google/google.go                                               | 109 +++
google/slice.go                                                |  11 
providertests/testdata/TestSimple/google-gemini-2.5-flash.yaml |  31 +
providertests/testdata/TestSimple/google-gemini-2.5-pro.yaml   |  31 +
providertests/testdata/TestTool/google-gemini-2.5-flash.yaml   |  24 
providertests/testdata/TestTool/google-gemini-2.5-pro.yaml     |  24 
6 files changed, 225 insertions(+), 5 deletions(-)

Detailed changes

google/google.go ๐Ÿ”—

@@ -11,12 +11,14 @@ import (
 	"strings"
 
 	"github.com/charmbracelet/fantasy/ai"
+	"github.com/charmbracelet/x/exp/slice"
 	"google.golang.org/genai"
 )
 
 type provider struct {
 	options options
 }
+
 type options struct {
 	apiKey  string
 	name    string
@@ -367,11 +369,27 @@ func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.Ca
 
 // Generate implements ai.LanguageModel.
 func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
-	// params, err := g.prepareParams(call)
-	// if err != nil {
-	// 	return nil, err
-	// }
-	return nil, errors.New("unimplemented")
+	// Translate the provider-agnostic call into genai request config,
+	// content turns, and any non-fatal warnings collected along the way.
+	config, contents, warnings, err := g.prepareParams(call)
+	if err != nil {
+		return nil, err
+	}
+
+	// The genai Chats API takes prior turns as history and the final turn
+	// as the outgoing message, so split the last content off the slice.
+	lastMessage, history, ok := slice.Pop(contents)
+	if !ok {
+		return nil, errors.New("no messages to send")
+	}
+
+	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
+	if err != nil {
+		return nil, err
+	}
+
+	// Parts are stored as pointers; SendMessage is variadic over values.
+	response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
+	if err != nil {
+		return nil, err
+	}
+
+	// Convert the genai response back into an ai.Response, carrying the
+	// warnings gathered while preparing params.
+	return mapResponse(response, warnings)
 }
 
 // Model implements ai.LanguageModel.
@@ -534,3 +552,84 @@ func mapJSONTypeToGoogle(jsonType string) genai.Type {
 		return genai.TypeString // Default to string for unknown types
 	}
 }
+
+// mapResponse converts a genai generate-content response into an
+// ai.Response, attaching the warnings collected while preparing the call.
+// Only the first candidate is inspected; text and function-call parts are
+// supported, any other part kind is rejected as unimplemented.
+func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) {
+	if len(response.Candidates) == 0 || response.Candidates[0].Content == nil {
+		return nil, errors.New("no response from model")
+	}
+
+	var (
+		content      []ai.Content
+		finishReason ai.FinishReason
+		hasToolCalls bool
+		candidate    = response.Candidates[0]
+	)
+
+	for _, part := range candidate.Content.Parts {
+		switch {
+		// NOTE(review): a part flagged as a thought would also carry Text and
+		// surface here as plain text — confirm thoughts are never requested.
+		case part.Text != "":
+			content = append(content, ai.TextContent{Text: part.Text})
+		case part.FunctionCall != nil:
+			// Tool-call arguments arrive as a decoded map; re-serialize them
+			// so the ai layer gets the raw JSON input string it expects.
+			input, err := json.Marshal(part.FunctionCall.Args)
+			if err != nil {
+				return nil, err
+			}
+			content = append(content, ai.ToolCallContent{
+				ToolCallID:       part.FunctionCall.ID,
+				ToolName:         part.FunctionCall.Name,
+				Input:            string(input),
+				ProviderExecuted: false,
+			})
+			hasToolCalls = true
+		default:
+			return nil, fmt.Errorf("not implemented part type")
+		}
+	}
+
+	// A turn that produced tool calls always finishes with "tool calls",
+	// regardless of the finish reason the provider reported.
+	if hasToolCalls {
+		finishReason = ai.FinishReasonToolCalls
+	} else {
+		finishReason = mapFinishReason(candidate.FinishReason)
+	}
+
+	return &ai.Response{
+		Content:      content,
+		Usage:        mapUsage(response.UsageMetadata),
+		FinishReason: finishReason,
+		Warnings:     warnings,
+	}, nil
+}
+
+// mapFinishReason maps a genai finish reason onto the provider-agnostic
+// ai.FinishReason. Safety-related reasons collapse into "content filter",
+// model-fault reasons (recitation, language, malformed function call) into
+// "error", and anything unrecognized into "unknown".
+func mapFinishReason(reason genai.FinishReason) ai.FinishReason {
+	switch reason {
+	case genai.FinishReasonStop:
+		return ai.FinishReasonStop
+	case genai.FinishReasonMaxTokens:
+		return ai.FinishReasonLength
+	case genai.FinishReasonSafety,
+		genai.FinishReasonBlocklist,
+		genai.FinishReasonProhibitedContent,
+		genai.FinishReasonSPII,
+		genai.FinishReasonImageSafety:
+		return ai.FinishReasonContentFilter
+	case genai.FinishReasonRecitation,
+		genai.FinishReasonLanguage,
+		genai.FinishReasonMalformedFunctionCall:
+		return ai.FinishReasonError
+	case genai.FinishReasonOther:
+		return ai.FinishReasonOther
+	default:
+		return ai.FinishReasonUnknown
+	}
+}
+
+// mapUsage converts genai usage metadata into ai.Usage.
+//
+// PromptTokenCount is the full input token count for the request (the
+// flash fixture's promptTokenCount=11 is the real input size), whereas
+// ToolUsePromptTokenCount only counts tokens in tool-use prompts.
+// CachedContentTokenCount reports tokens served from the context cache —
+// a cache read; the Gemini API exposes no cache-creation count here.
+func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage {
+	// Usage metadata can be absent on some error/edge responses.
+	if usage == nil {
+		return ai.Usage{}
+	}
+	return ai.Usage{
+		InputTokens:         int64(usage.PromptTokenCount),
+		OutputTokens:        int64(usage.CandidatesTokenCount),
+		TotalTokens:         int64(usage.TotalTokenCount),
+		ReasoningTokens:     int64(usage.ThoughtsTokenCount),
+		CacheCreationTokens: 0,
+		CacheReadTokens:     int64(usage.CachedContentTokenCount),
+	}
+}

google/slice.go ๐Ÿ”—

@@ -0,0 +1,11 @@
+package google
+
+// depointerSlice dereferences every non-nil pointer in s and returns the
+// resulting values; nil entries are silently dropped.
+func depointerSlice[T any](s []*T) []T {
+	result := make([]T, 0, len(s))
+	for _, v := range s {
+		if v != nil {
+			result = append(result, *v)
+		}
+	}
+	return result
+}

providertests/testdata/TestSimple/google-gemini-2.5-flash.yaml ๐Ÿ”—

@@ -0,0 +1,31 @@
+---
+version: 2
+interactions:
+- id: 0
+  request:
+    proto: HTTP/1.1
+    proto_major: 1
+    proto_minor: 1
+    content_length: 180
+    host: generativelanguage.googleapis.com
+    body: "{\"contents\":[{\"parts\":[{\"text\":\"Say hi in Portuguese\"}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"}}\n"
+    headers:
+      Content-Type:
+      - application/json
+      User-Agent:
+      - google-genai-sdk/1.23.0 gl-go/go1.24.5
+    url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent
+    method: POST
+  response:
+    proto: HTTP/2.0
+    proto_major: 2
+    proto_minor: 0
+    content_length: -1
+    uncompressed: true
+    body: "{\n  \"candidates\": [\n    {\n      \"content\": {\n        \"parts\": [\n          {\n            \"text\": \"Olรก!\"\n          }\n        ],\n        \"role\": \"model\"\n      },\n      \"finishReason\": \"STOP\",\n      \"index\": 0\n    }\n  ],\n  \"usageMetadata\": {\n    \"promptTokenCount\": 11,\n    \"candidatesTokenCount\": 2,\n    \"totalTokenCount\": 39,\n    \"promptTokensDetails\": [\n      {\n        \"modality\": \"TEXT\",\n        \"tokenCount\": 11\n      }\n    ],\n    \"thoughtsTokenCount\": 26\n  },\n  \"modelVersion\": \"gemini-2.5-flash\",\n  \"responseId\": \"_Ei7aJ_lFZ7nz7IPwKK82Qw\"\n}\n"
+    headers:
+      Content-Type:
+      - application/json; charset=UTF-8
+    status: 200 OK
+    code: 200
+    duration: 1.683615083s

providertests/testdata/TestSimple/google-gemini-2.5-pro.yaml ๐Ÿ”—

@@ -0,0 +1,31 @@
+---
+version: 2
+interactions:
+- id: 0
+  request:
+    proto: HTTP/1.1
+    proto_major: 1
+    proto_minor: 1
+    content_length: 180
+    host: generativelanguage.googleapis.com
+    body: "{\"contents\":[{\"parts\":[{\"text\":\"Say hi in Portuguese\"}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"}}\n"
+    headers:
+      Content-Type:
+      - application/json
+      User-Agent:
+      - google-genai-sdk/1.23.0 gl-go/go1.24.5
+    url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro:generateContent
+    method: POST
+  response:
+    proto: HTTP/2.0
+    proto_major: 2
+    proto_minor: 0
+    content_length: -1
+    uncompressed: true
+    body: "{\n  \"candidates\": [\n    {\n      \"content\": {\n        \"parts\": [\n          {\n            \"text\": \"Olá!\\n\\nIn Portuguese, \\\"hi\\\" can be translated as:\\n\\n*   **Oi** (very common and informal)\\n*   **Olá** (a bit more formal, but also widely used)\"\n          }\n        ],\n        \"role\": \"model\"\n      },\n      \"finishReason\": \"STOP\",\n      \"index\": 0\n    }\n  ],\n  \"usageMetadata\": {\n    \"promptTokenCount\": 11,\n    \"candidatesTokenCount\": 43,\n    \"totalTokenCount\": 77,\n    \"promptTokensDetails\": [\n      {\n        \"modality\": \"TEXT\",\n        \"tokenCount\": 11\n      }\n    ],\n    \"thoughtsTokenCount\": 23\n  },\n  \"modelVersion\": \"gemini-2.5-pro\",\n  \"responseId\": \"_Ui7aL_qEoCsz7IPmMvIqQ4\"\n}\n"
+    headers:
+      Content-Type:
+      - application/json; charset=UTF-8
+    status: 200 OK
+    code: 200
+    duration: 870.503208ms

providertests/testdata/TestTool/google-gemini-2.5-flash.yaml ๐Ÿ”—

@@ -0,0 +1,59 @@
+---
+version: 2
+interactions:
+- id: 0
+  request:
+    proto: HTTP/1.1
+    proto_major: 1
+    proto_minor: 1
+    content_length: 481
+    host: generativelanguage.googleapis.com
+    body: "{\"contents\":[{\"parts\":[{\"text\":\"What's the weather in Florence?\"}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"},\"toolConfig\":{\"functionCallingConfig\":{\"mode\":\"AUTO\"}},\"tools\":[{\"functionDeclarations\":[{\"description\":\"Get weather information for a location\",\"name\":\"weather\",\"parameters\":{\"properties\":{\"location\":{\"description\":\"the city\",\"type\":\"STRING\"}},\"required\":[\"location\"],\"type\":\"OBJECT\"}}]}]}\n"
+    headers:
+      Content-Type:
+      - application/json
+      User-Agent:
+      - google-genai-sdk/1.23.0 gl-go/go1.24.5
+    url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent
+    method: POST
+  response:
+    proto: HTTP/2.0
+    proto_major: 2
+    proto_minor: 0
+    content_length: -1
+    uncompressed: true

providertests/testdata/TestTool/google-gemini-2.5-pro.yaml ๐Ÿ”—

@@ -0,0 +1,59 @@
+---
+version: 2
+interactions:
+- id: 0
+  request:
+    proto: HTTP/1.1
+    proto_major: 1
+    proto_minor: 1
+    content_length: 481
+    host: generativelanguage.googleapis.com
+    body: "{\"contents\":[{\"parts\":[{\"text\":\"What's the weather in Florence?\"}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"},\"toolConfig\":{\"functionCallingConfig\":{\"mode\":\"AUTO\"}},\"tools\":[{\"functionDeclarations\":[{\"description\":\"Get weather information for a location\",\"name\":\"weather\",\"parameters\":{\"properties\":{\"location\":{\"description\":\"the city\",\"type\":\"STRING\"}},\"required\":[\"location\"],\"type\":\"OBJECT\"}}]}]}\n"
+    headers:
+      Content-Type:
+      - application/json
+      User-Agent:
+      - google-genai-sdk/1.23.0 gl-go/go1.24.5
+    url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro:generateContent
+    method: POST
+  response:
+    proto: HTTP/2.0
+    proto_major: 2
+    proto_minor: 0
+    content_length: -1
+    uncompressed: true