From a4fd3fae6e27c970af7bd3de5c96529e384ef484 Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Fri, 5 Sep 2025 17:35:48 -0300 Subject: [PATCH] feat: make google gemini work for basic + tool calls --- google/google.go | 109 +++++++++++++++++- google/slice.go | 11 ++ .../TestSimple/google-gemini-2.5-flash.yaml | 31 +++++ .../TestSimple/google-gemini-2.5-pro.yaml | 31 +++++ .../TestTool/google-gemini-2.5-flash.yaml | 59 ++++++++++ .../TestTool/google-gemini-2.5-pro.yaml | 59 ++++++++++ 6 files changed, 295 insertions(+), 5 deletions(-) create mode 100644 google/slice.go create mode 100644 providertests/testdata/TestSimple/google-gemini-2.5-flash.yaml create mode 100644 providertests/testdata/TestSimple/google-gemini-2.5-pro.yaml create mode 100644 providertests/testdata/TestTool/google-gemini-2.5-flash.yaml create mode 100644 providertests/testdata/TestTool/google-gemini-2.5-pro.yaml diff --git a/google/google.go b/google/google.go index 3335dca66c758ba4f1603deca387a73f945dd9a9..80c45217f202ebe8b6bbe9bccfaa6c42a824d0ed 100644 --- a/google/google.go +++ b/google/google.go @@ -11,12 +11,14 @@ import ( "strings" "github.com/charmbracelet/fantasy/ai" + "github.com/charmbracelet/x/exp/slice" "google.golang.org/genai" ) type provider struct { options options } + type options struct { apiKey string name string @@ -367,11 +369,27 @@ func toGooglePrompt(prompt ai.Prompt) (*genai.Content, []*genai.Content, []ai.Ca // Generate implements ai.LanguageModel. func (g *languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { - // params, err := g.prepareParams(call) - // if err != nil { - // return nil, err - // } - return nil, errors.New("unimplemented") + config, contents, warnings, err := g.prepareParams(call) + if err != nil { + return nil, err + } + + lastMessage, history, ok := slice.Pop(contents) + if !ok { + return nil, errors.New("no messages to send") + } + + chat, err := g.client.Chats.Create(ctx, g.modelID, config, history) + if err != nil { + return nil, err + } + + response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...) + if err != nil { + return nil, err + } + + return mapResponse(response, warnings) } // Model implements ai.LanguageModel. @@ -534,3 +552,84 @@ func mapJSONTypeToGoogle(jsonType string) genai.Type { return genai.TypeString // Default to string for unknown types } } + +func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarning) (*ai.Response, error) { + if len(response.Candidates) == 0 || response.Candidates[0].Content == nil { + return nil, errors.New("no response from model") + } + + var ( + content []ai.Content + finishReason ai.FinishReason + hasToolCalls bool + candidate = response.Candidates[0] + ) + + for _, part := range candidate.Content.Parts { + switch { + case part.Text != "": + content = append(content, ai.TextContent{Text: part.Text}) + case part.FunctionCall != nil: + input, err := json.Marshal(part.FunctionCall.Args) + if err != nil { + return nil, err + } + content = append(content, ai.ToolCallContent{ + ToolCallID: part.FunctionCall.ID, + ToolName: part.FunctionCall.Name, + Input: string(input), + ProviderExecuted: false, + }) + hasToolCalls = true + default: + return nil, fmt.Errorf("not implemented part type") + } + } + + if hasToolCalls { + finishReason = ai.FinishReasonToolCalls + } else { + finishReason = mapFinishReason(candidate.FinishReason) + } + + return &ai.Response{ + Content: content, + Usage: mapUsage(response.UsageMetadata), + FinishReason: finishReason, + Warnings: warnings, + }, nil +} + +func mapFinishReason(reason genai.FinishReason) ai.FinishReason { + switch reason { + case genai.FinishReasonStop: + return ai.FinishReasonStop + case genai.FinishReasonMaxTokens: + return ai.FinishReasonLength + case genai.FinishReasonSafety, + genai.FinishReasonBlocklist, + genai.FinishReasonProhibitedContent, + genai.FinishReasonSPII, + genai.FinishReasonImageSafety: + return ai.FinishReasonContentFilter + case genai.FinishReasonRecitation, + genai.FinishReasonLanguage, + genai.FinishReasonMalformedFunctionCall: + return ai.FinishReasonError + case genai.FinishReasonOther: + return ai.FinishReasonOther + default: + return ai.FinishReasonUnknown + } +} + +func mapUsage(usage *genai.GenerateContentResponseUsageMetadata) ai.Usage { + return ai.Usage{ + InputTokens: int64(usage.ToolUsePromptTokenCount), + OutputTokens: int64(usage.CandidatesTokenCount), + TotalTokens: int64(usage.TotalTokenCount), + ReasoningTokens: int64(usage.ThoughtsTokenCount), + CacheCreationTokens: int64(usage.CachedContentTokenCount), + CacheReadTokens: 0, + } +} diff --git a/google/slice.go b/google/slice.go new file mode 100644 index 0000000000000000000000000000000000000000..215355efedd67fe59ce27cc90453c0a604bec5e3 --- /dev/null +++ b/google/slice.go @@ -0,0 +1,11 @@ +package google + +func depointerSlice[T any](s []*T) []T { + result := make([]T, 0, len(s)) + for _, v := range s { + if v != nil { + result = append(result, *v) + } + } + return result +} diff --git a/providertests/testdata/TestSimple/google-gemini-2.5-flash.yaml b/providertests/testdata/TestSimple/google-gemini-2.5-flash.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6c14a2a7c570877f02aa5e7de77c9d0b98f5ea95 --- /dev/null +++ b/providertests/testdata/TestSimple/google-gemini-2.5-flash.yaml @@ -0,0 +1,31 @@ +--- +version: 2 +interactions: +- id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 180 + host: generativelanguage.googleapis.com + body: "{\"contents\":[{\"parts\":[{\"text\":\"Say hi in Portuguese\"}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"}}\n" + headers: + Content-Type: + - application/json + User-Agent: + - google-genai-sdk/1.23.0 gl-go/go1.24.5 + url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + uncompressed: true + body: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \"Olá!\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 11,\n \"candidatesTokenCount\": 2,\n \"totalTokenCount\": 39,\n \"promptTokensDetails\": [\n {\n \"modality\": \"TEXT\",\n \"tokenCount\": 11\n }\n ],\n \"thoughtsTokenCount\": 26\n },\n \"modelVersion\": \"gemini-2.5-flash\",\n \"responseId\": \"_Ei7aJ_lFZ7nz7IPwKK82Qw\"\n}\n" + headers: + Content-Type: + - application/json; charset=UTF-8 + status: 200 OK + code: 200 + duration: 1.683615083s diff --git a/providertests/testdata/TestSimple/google-gemini-2.5-pro.yaml b/providertests/testdata/TestSimple/google-gemini-2.5-pro.yaml new file mode 100644 index 0000000000000000000000000000000000000000..75d4f9b9a190dee90fd2ed3db503fe44bb3dfcf2 --- /dev/null +++ b/providertests/testdata/TestSimple/google-gemini-2.5-pro.yaml @@ -0,0 +1,31 @@ +--- +version: 2 +interactions: +- id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 180 + host: generativelanguage.googleapis.com + body: "{\"contents\":[{\"parts\":[{\"text\":\"Say hi in Portuguese\"}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"}}\n" + headers: + Content-Type: + - application/json + User-Agent: + - google-genai-sdk/1.23.0 gl-go/go1.24.5 + url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + uncompressed: true + body: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \"Olá!\\n\\nIn Portuguese, \\\"hi\\\" can be translated as:\\n\\n* **Oi** (very common and informal)\\n* **Olá** (a bit more formal, but also widely used)\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 11,\n \"candidatesTokenCount\": 43,\n \"totalTokenCount\": 77,\n \"promptTokensDetails\": [\n {\n \"modality\": \"TEXT\",\n \"tokenCount\": 11\n }\n ],\n \"thoughtsTokenCount\": 23\n },\n \"modelVersion\": \"gemini-2.5-flash\",\n \"responseId\": \"_Ui7aL_qEoCsz7IPmMvIqQ4\"\n}\n" + headers: + Content-Type: + - application/json; charset=UTF-8 + status: 200 OK + code: 200 + duration: 870.503208ms diff --git a/providertests/testdata/TestTool/google-gemini-2.5-flash.yaml b/providertests/testdata/TestTool/google-gemini-2.5-flash.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9d9b8efbecb56fda34f8cafcc64ed3f3f0719db2 --- /dev/null +++ b/providertests/testdata/TestTool/google-gemini-2.5-flash.yaml @@ -0,0 +1,59 @@ +--- +version: 2 +interactions: +- id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 481 + host: generativelanguage.googleapis.com + body: "{\"contents\":[{\"parts\":[{\"text\":\"What's the weather in Florence?\"}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"},\"toolConfig\":{\"functionCallingConfig\":{\"mode\":\"AUTO\"}},\"tools\":[{\"functionDeclarations\":[{\"description\":\"Get weather information for a location\",\"name\":\"weather\",\"parameters\":{\"properties\":{\"location\":{\"description\":\"the city\",\"type\":\"STRING\"}},\"required\":[\"location\"],\"type\":\"OBJECT\"}}]}]}\n" + headers: + Content-Type: + - application/json + User-Agent: + - google-genai-sdk/1.23.0 gl-go/go1.24.5 + url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + uncompressed: true + body: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"functionCall\": {\n \"name\": \"weather\",\n \"args\": {\n \"location\": \"Florence\"\n }\n },\n \"thoughtSignature\": \"CpUCAVSoXO7jOuZmcU+HAEV/vBr3c+MDukNpwF7G8D2V1FB5+SS3o4GLSA4PqaT+0qwv6xtHan6ExItRzcphhhqVD4JUK13scMvwfy7r2r3KjbyFLq/JqQjMY5KHJXxuYu5GKuwKyYJHhHTjP0p62ORdPaUZg3umBcdZM3nD9bMneU6zIPBB4l/t0s69c/+0nooCeYI+r9s7NpI2AvqQNKjaiIeJrC1W1Qw1cUAB7L/l1ZJjqiZ5CiBTiW4WJ4pxhP9WT3vwqqv9SFi1FpvycBiLCUmLkH/rHmHHoq2ExxRUANcgqXiF0q3beAP+W1iudaGHkz9023CrN6YR5oQj6ABqYTYVDYTE8qG95MWYhWk6suiv9SD/MA==\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 52,\n \"candidatesTokenCount\": 13,\n \"totalTokenCount\": 121,\n \"promptTokensDetails\": [\n {\n \"modality\": \"TEXT\",\n \"tokenCount\": 52\n }\n ],\n \"thoughtsTokenCount\": 56\n },\n \"modelVersion\": \"gemini-2.5-flash\",\n \"responseId\": \"_ki7aPmbEbLhz7IP8vvqoQk\"\n}\n" + headers: + Content-Type: + - application/json; charset=UTF-8 + status: 200 OK + code: 200 + duration: 1.063079458s +- id: 1 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 670 + host: generativelanguage.googleapis.com + body: "{\"contents\":[{\"parts\":[{\"text\":\"What's the weather in Florence?\"}],\"role\":\"user\"},{\"parts\":[{\"functionCall\":{\"args\":{\"location\":\"Florence\"},\"name\":\"weather\"}}],\"role\":\"model\"},{\"parts\":[{\"functionResponse\":{\"name\":\"weather\",\"response\":{\"result\":\"40 C\"}}}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"},\"toolConfig\":{\"functionCallingConfig\":{\"mode\":\"AUTO\"}},\"tools\":[{\"functionDeclarations\":[{\"description\":\"Get weather information for a location\",\"name\":\"weather\",\"parameters\":{\"properties\":{\"location\":{\"description\":\"the city\",\"type\":\"STRING\"}},\"required\":[\"location\"],\"type\":\"OBJECT\"}}]}]}\n" + headers: + Content-Type: + - application/json + User-Agent: + - google-genai-sdk/1.23.0 gl-go/go1.24.5 + url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + uncompressed: true + body: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \"The weather in Florence is 40 C.\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 80,\n \"candidatesTokenCount\": 10,\n \"totalTokenCount\": 90,\n \"promptTokensDetails\": [\n {\n \"modality\": \"TEXT\",\n \"tokenCount\": 80\n }\n ]\n },\n \"modelVersion\": \"gemini-2.5-flash\",\n \"responseId\": \"_ki7aN-APaaHz7IPrL60iQE\"\n}\n" + headers: + Content-Type: + - application/json; charset=UTF-8 + status: 200 OK + code: 200 + duration: 678.656417ms diff --git a/providertests/testdata/TestTool/google-gemini-2.5-pro.yaml b/providertests/testdata/TestTool/google-gemini-2.5-pro.yaml new file mode 100644 index 0000000000000000000000000000000000000000..14b5149ce9277d4fdfc47b709590e3adfb789ac5 --- /dev/null +++ b/providertests/testdata/TestTool/google-gemini-2.5-pro.yaml @@ -0,0 +1,59 @@ +--- +version: 2 +interactions: +- id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 481 + host: generativelanguage.googleapis.com + body: "{\"contents\":[{\"parts\":[{\"text\":\"What's the weather in Florence?\"}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"},\"toolConfig\":{\"functionCallingConfig\":{\"mode\":\"AUTO\"}},\"tools\":[{\"functionDeclarations\":[{\"description\":\"Get weather information for a location\",\"name\":\"weather\",\"parameters\":{\"properties\":{\"location\":{\"description\":\"the city\",\"type\":\"STRING\"}},\"required\":[\"location\"],\"type\":\"OBJECT\"}}]}]}\n" + headers: + Content-Type: + - application/json + User-Agent: + - google-genai-sdk/1.23.0 gl-go/go1.24.5 + url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + uncompressed: true + body: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"functionCall\": {\n \"name\": \"weather\",\n \"args\": {\n \"location\": \"Florence\"\n }\n },\n \"thoughtSignature\": \"CuoBAVSoXO61VKMhhJ8LyVTCpi9Wo5Mak4xOizBZB9m9DlJngMKxObtT2VSXeZoBWIjXv7MfNZy0wldQImsslSMFK7HGKWc84fxnvkzNydv9MJSXSD+PcNKDSgPbthGCY4nt5PVxkM4DKYvW+k3YC+yTOHzKwr30SMig0YXpYQ7RUHOk5Saz8hwegYyWR/nVoGpkDq+1vQRB5vaU2RZbTlW77yFalUgAPv6uKgGArCKNbSCdhA+bOXRhAlsc8FwRF5iEzwYQSlkaACtp3MntyPr5AWHJageAtUjujoku9ZiObIX7hJ17gnaCY3NA\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 52,\n \"candidatesTokenCount\": 13,\n \"totalTokenCount\": 111,\n \"promptTokensDetails\": [\n {\n \"modality\": \"TEXT\",\n \"tokenCount\": 52\n }\n ],\n \"thoughtsTokenCount\": 46\n },\n \"modelVersion\": \"gemini-2.5-flash\",\n \"responseId\": \"_0i7aODLOP7gz7IPpObKgAc\"\n}\n" + headers: + Content-Type: + - application/json; charset=UTF-8 + status: 200 OK + code: 200 + duration: 866.858875ms +- id: 1 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 670 + host: generativelanguage.googleapis.com + body: "{\"contents\":[{\"parts\":[{\"text\":\"What's the weather in Florence?\"}],\"role\":\"user\"},{\"parts\":[{\"functionCall\":{\"args\":{\"location\":\"Florence\"},\"name\":\"weather\"}}],\"role\":\"model\"},{\"parts\":[{\"functionResponse\":{\"name\":\"weather\",\"response\":{\"result\":\"40 C\"}}}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"},\"toolConfig\":{\"functionCallingConfig\":{\"mode\":\"AUTO\"}},\"tools\":[{\"functionDeclarations\":[{\"description\":\"Get weather information for a location\",\"name\":\"weather\",\"parameters\":{\"properties\":{\"location\":{\"description\":\"the city\",\"type\":\"STRING\"}},\"required\":[\"location\"],\"type\":\"OBJECT\"}}]}]}\n" + headers: + Content-Type: + - application/json + User-Agent: + - google-genai-sdk/1.23.0 gl-go/go1.24.5 + url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + uncompressed: true + body: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \"The weather in Florence is 40 C.\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 80,\n \"candidatesTokenCount\": 10,\n \"totalTokenCount\": 90,\n \"promptTokensDetails\": [\n {\n \"modality\": \"TEXT\",\n \"tokenCount\": 80\n }\n ]\n },\n \"modelVersion\": \"gemini-2.5-flash\",\n \"responseId\": \"AEm7aKbpJdDVz7IPlO776Ak\"\n}\n" + headers: + Content-Type: + - application/json; charset=UTF-8 + status: 200 OK + code: 200 + duration: 688.650291ms