feat(gemini): implement streaming

Author: Andrey Nering

Change summary

google/google.go                                                        | 159 
providertests/testdata/TestStream/google-gemini-2.5-flash.yaml          |  33 
providertests/testdata/TestStream/google-gemini-2.5-pro.yaml            |  33 
providertests/testdata/TestStreamWithTools/google-gemini-2.5-flash.yaml |  26 
providertests/testdata/TestStreamWithTools/google-gemini-2.5-pro.yaml   |  26 
5 files changed, 275 insertions(+), 2 deletions(-)

Detailed changes

google/google.go

@@ -1,6 +1,7 @@
 package google
 
 import (
+	"cmp"
 	"context"
 	"encoding/base64"
 	"encoding/json"
@@ -12,6 +13,7 @@ import (
 
 	"github.com/charmbracelet/fantasy/ai"
 	"github.com/charmbracelet/x/exp/slice"
+	"github.com/google/uuid"
 	"google.golang.org/genai"
 )
 
@@ -403,8 +405,161 @@ func (g *languageModel) Provider() string {
 }
 
 // Stream implements ai.LanguageModel.
-func (g *languageModel) Stream(context.Context, ai.Call) (ai.StreamResponse, error) {
-	return nil, errors.New("unimplemented")
+func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
+	config, contents, warnings, err := g.prepareParams(call)
+	if err != nil {
+		return nil, err
+	}
+
+	lastMessage, history, ok := slice.Pop(contents)
+	if !ok {
+		return nil, errors.New("no messages to send")
+	}
+
+	chat, err := g.client.Chats.Create(ctx, g.modelID, config, history)
+	if err != nil {
+		return nil, err
+	}
+
+	return func(yield func(ai.StreamPart) bool) {
+		if len(warnings) > 0 {
+			if !yield(ai.StreamPart{
+				Type:     ai.StreamPartTypeWarnings,
+				Warnings: warnings,
+			}) {
+				return
+			}
+		}
+
+		var currentContent string
+		var toolCalls []ai.ToolCallContent
+		var isActiveText bool
+		var usage ai.Usage
+
+		// Stream the response
+		for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
+			if err != nil {
+				yield(ai.StreamPart{
+					Type:  ai.StreamPartTypeError,
+					Error: err,
+				})
+				return
+			}
+
+			if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
+				for _, part := range resp.Candidates[0].Content.Parts {
+					switch {
+					case part.Text != "":
+						delta := part.Text
+						if delta != "" {
+							if !isActiveText {
+								isActiveText = true
+								if !yield(ai.StreamPart{
+									Type: ai.StreamPartTypeTextStart,
+									ID:   "0",
+								}) {
+									return
+								}
+							}
+							if !yield(ai.StreamPart{
+								Type:  ai.StreamPartTypeTextDelta,
+								ID:    "0",
+								Delta: delta,
+							}) {
+								return
+							}
+							currentContent += delta
+						}
+					case part.FunctionCall != nil:
+						if isActiveText {
+							isActiveText = false
+							if !yield(ai.StreamPart{
+								Type: ai.StreamPartTypeTextEnd,
+								ID:   "0",
+							}) {
+								return
+							}
+						}
+
+						toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
+
+						args, err := json.Marshal(part.FunctionCall.Args)
+						if err != nil {
+							yield(ai.StreamPart{
+								Type:  ai.StreamPartTypeError,
+								Error: err,
+							})
+							return
+						}
+
+						if !yield(ai.StreamPart{
+							Type:         ai.StreamPartTypeToolInputStart,
+							ID:           toolCallID,
+							ToolCallName: part.FunctionCall.Name,
+						}) {
+							return
+						}
+
+						if !yield(ai.StreamPart{
+							Type:  ai.StreamPartTypeToolInputDelta,
+							ID:    toolCallID,
+							Delta: string(args),
+						}) {
+							return
+						}
+
+						if !yield(ai.StreamPart{
+							Type: ai.StreamPartTypeToolInputEnd,
+							ID:   toolCallID,
+						}) {
+							return
+						}
+
+						if !yield(ai.StreamPart{
+							Type:             ai.StreamPartTypeToolCall,
+							ID:               toolCallID,
+							ToolCallName:     part.FunctionCall.Name,
+							ToolCallInput:    string(args),
+							ProviderExecuted: false,
+						}) {
+							return
+						}
+
+						toolCalls = append(toolCalls, ai.ToolCallContent{
+							ToolCallID:       toolCallID,
+							ToolName:         part.FunctionCall.Name,
+							Input:            string(args),
+							ProviderExecuted: false,
+						})
+					}
+				}
+			}
+
+			if resp.UsageMetadata != nil {
+				usage = mapUsage(resp.UsageMetadata)
+			}
+		}
+
+		if isActiveText {
+			if !yield(ai.StreamPart{
+				Type: ai.StreamPartTypeTextEnd,
+				ID:   "0",
+			}) {
+				return
+			}
+		}
+
+		finishReason := ai.FinishReasonStop
+		if len(toolCalls) > 0 {
+			finishReason = ai.FinishReasonToolCalls
+		}
+
+		yield(ai.StreamPart{
+			Type:         ai.StreamPartTypeFinish,
+			Usage:        usage,
+			FinishReason: finishReason,
+		})
+	}, nil
 }
 
 func toGoogleTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (googleTools []*genai.FunctionDeclaration, googleToolChoice *genai.ToolConfig, warnings []ai.CallWarning) {
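
For reference, a minimal consumer sketch of the new streaming API. It assumes ai.StreamResponse is a Go range-over-func iterator over ai.StreamPart (the function literal returned above matches that shape) and that constructing the model and the ai.Call happens elsewhere; the package and helper name (streamToStdout) are hypothetical and not part of this change.

package example

import (
	"context"
	"fmt"

	"github.com/charmbracelet/fantasy/ai"
)

// streamToStdout drains the parts yielded by languageModel.Stream and prints
// the ones a typical caller cares about. Only the part types and fields
// handled below are taken from this change; everything else is assumed.
func streamToStdout(ctx context.Context, model ai.LanguageModel, call ai.Call) error {
	stream, err := model.Stream(ctx, call)
	if err != nil {
		return err
	}
	for part := range stream { // assumes StreamResponse is a range-over-func iterator (Go 1.23+)
		switch part.Type {
		case ai.StreamPartTypeTextDelta:
			fmt.Print(part.Delta) // incremental assistant text
		case ai.StreamPartTypeToolCall:
			fmt.Printf("\ntool call: %s(%s)\n", part.ToolCallName, part.ToolCallInput)
		case ai.StreamPartTypeError:
			return fmt.Errorf("stream error: %w", part.Error)
		case ai.StreamPartTypeFinish:
			fmt.Printf("\nfinish reason: %v, usage: %+v\n", part.FinishReason, part.Usage)
		}
	}
	return nil
}
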

providertests/testdata/TestStream/google-gemini-2.5-flash.yaml

@@ -0,0 +1,33 @@
+---
+version: 2
+interactions:
+- id: 0
+  request:
+    proto: HTTP/1.1
+    proto_major: 1
+    proto_minor: 1
+    content_length: 188
+    host: generativelanguage.googleapis.com
+    body: "{\"contents\":[{\"parts\":[{\"text\":\"Count from 1 to 3 in Spanish\"}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"}}\n"
+    form:
+      alt:
+      - sse
+    headers:
+      Content-Type:
+      - application/json
+      User-Agent:
+      - google-genai-sdk/1.23.0 gl-go/go1.24.5
+    url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse
+    method: POST
+  response:
+    proto: HTTP/2.0
+    proto_major: 2
+    proto_minor: 0
+    content_length: -1
+    body: "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"Here you go:\\n\\n1.  **Uno**\\n2.  **Dos**\\n3.  **Tres**\"}],\"role\": \"model\"},\"finishReason\": \"STOP\",\"index\": 0}],\"usageMetadata\": {\"promptTokenCount\": 16,\"candidatesTokenCount\": 25,\"totalTokenCount\": 74,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 16}],\"thoughtsTokenCount\": 33},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"jl3AaOWjGefyqtsPi_C6sAM\"}\r\n\r\n"
+    headers:
+      Content-Type:
+      - text/event-stream
+    status: 200 OK
+    code: 200
+    duration: 1.178272625s

providertests/testdata/TestStream/google-gemini-2.5-pro.yaml

@@ -0,0 +1,33 @@
+---
+version: 2
+interactions:
+- id: 0
+  request:
+    proto: HTTP/1.1
+    proto_major: 1
+    proto_minor: 1
+    content_length: 188
+    host: generativelanguage.googleapis.com
+    body: "{\"contents\":[{\"parts\":[{\"text\":\"Count from 1 to 3 in Spanish\"}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant\"}],\"role\":\"user\"}}\n"
+    form:
+      alt:
+      - sse
+    headers:
+      Content-Type:
+      - application/json
+      User-Agent:
+      - google-genai-sdk/1.23.0 gl-go/go1.24.5
+    url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse
+    method: POST
+  response:
+    proto: HTTP/2.0
+    proto_major: 2
+    proto_minor: 0
+    content_length: -1
+    body: "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"Claro:\\n\\n1.  **Uno**\\n2.  **Dos**\\n3.  **Tres**\"}],\"role\": \"model\"},\"finishReason\": \"STOP\",\"index\": 0}],\"usageMetadata\": {\"promptTokenCount\": 16,\"candidatesTokenCount\": 24,\"totalTokenCount\": 67,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 16}],\"thoughtsTokenCount\": 27},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"j13AaJWDJfeHqtsPv4Tk0QY\"}\r\n\r\n"
+    headers:
+      Content-Type:
+      - text/event-stream
+    status: 200 OK
+    code: 200
+    duration: 1.10918025s

providertests/testdata/TestStreamWithTools/google-gemini-2.5-flash.yaml

@@ -0,0 +1,63 @@
+---
+version: 2
+interactions:
+- id: 0
+  request:
+    proto: HTTP/1.1
+    proto_major: 1
+    proto_minor: 1
+    content_length: 530
+    host: generativelanguage.googleapis.com
+    body: "{\"contents\":[{\"parts\":[{\"text\":\"What is 15 + 27?\"}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant. Use the add tool to perform calculations.\"}],\"role\":\"user\"},\"toolConfig\":{\"functionCallingConfig\":{\"mode\":\"AUTO\"}},\"tools\":[{\"functionDeclarations\":[{\"description\":\"Add two numbers\",\"name\":\"add\",\"parameters\":{\"properties\":{\"a\":{\"description\":\"first number\",\"type\":\"INTEGER\"},\"b\":{\"description\":\"second number\",\"type\":\"INTEGER\"}},\"required\":[\"a\",\"b\"],\"type\":\"OBJECT\"}}]}]}\n"
+    form:
+      alt:
+      - sse
+    headers:
+      Content-Type:
+      - application/json
+      User-Agent:
+      - google-genai-sdk/1.23.0 gl-go/go1.24.5
+    url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse
+    method: POST
+  response:
+    proto: HTTP/2.0
+    proto_major: 2
+    proto_minor: 0
+    content_length: -1

providertests/testdata/TestStreamWithTools/google-gemini-2.5-pro.yaml

@@ -0,0 +1,63 @@
+---
+version: 2
+interactions:
+- id: 0
+  request:
+    proto: HTTP/1.1
+    proto_major: 1
+    proto_minor: 1
+    content_length: 530
+    host: generativelanguage.googleapis.com
+    body: "{\"contents\":[{\"parts\":[{\"text\":\"What is 15 + 27?\"}],\"role\":\"user\"}],\"generationConfig\":{},\"systemInstruction\":{\"parts\":[{\"text\":\"You are a helpful assistant. Use the add tool to perform calculations.\"}],\"role\":\"user\"},\"toolConfig\":{\"functionCallingConfig\":{\"mode\":\"AUTO\"}},\"tools\":[{\"functionDeclarations\":[{\"description\":\"Add two numbers\",\"name\":\"add\",\"parameters\":{\"properties\":{\"a\":{\"description\":\"first number\",\"type\":\"INTEGER\"},\"b\":{\"description\":\"second number\",\"type\":\"INTEGER\"}},\"required\":[\"a\",\"b\"],\"type\":\"OBJECT\"}}]}]}\n"
+    form:
+      alt:
+      - sse
+    headers:
+      Content-Type:
+      - application/json
+      User-Agent:
+      - google-genai-sdk/1.23.0 gl-go/go1.24.5
+    url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse
+    method: POST
+  response:
+    proto: HTTP/2.0
+    proto_major: 2
+    proto_minor: 0
+    content_length: -1