feat(gemini): fixes for streaming + thinking

Andrey Nering created

Change summary

google/google.go                                             | 124 +++++
providertests/testdata/TestTool/google-gemini-2.5-flash.yaml |   0 
providertests/testdata/TestTool/google-gemini-2.5-pro.yaml   |   0 
3 files changed, 105 insertions(+), 19 deletions(-)

Detailed changes

google/google.go 🔗

@@ -431,9 +431,13 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp
 		var currentContent string
 		var toolCalls []ai.ToolCallContent
 		var isActiveText bool
+		var isActiveReasoning bool
+		var blockCounter int
+		var currentTextBlockID string
+		var currentReasoningBlockID string
 		var usage ai.Usage
+		var lastFinishReason ai.FinishReason
 
-		// Stream the response
 		for resp, err := range chat.SendMessageStream(ctx, depointerSlice(lastMessage.Parts)...) {
 			if err != nil {
 				yield(ai.StreamPart{
@@ -449,30 +453,91 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp
 					case part.Text != "":
 						delta := part.Text
 						if delta != "" {
-							if !isActiveText {
-								isActiveText = true
+							// Check if this is a reasoning/thought part
+							if part.Thought {
+								// End any active text block before starting reasoning
+								if isActiveText {
+									isActiveText = false
+									if !yield(ai.StreamPart{
+										Type: ai.StreamPartTypeTextEnd,
+										ID:   currentTextBlockID,
+									}) {
+										return
+									}
+								}
+
+								// Start new reasoning block if not already active
+								if !isActiveReasoning {
+									isActiveReasoning = true
+									currentReasoningBlockID = fmt.Sprintf("%d", blockCounter)
+									blockCounter++
+									if !yield(ai.StreamPart{
+										Type: ai.StreamPartTypeReasoningStart,
+										ID:   currentReasoningBlockID,
+									}) {
+										return
+									}
+								}
+
 								if !yield(ai.StreamPart{
-									Type: ai.StreamPartTypeTextStart,
-									ID:   "0",
+									Type:  ai.StreamPartTypeReasoningDelta,
+									ID:    currentReasoningBlockID,
+									Delta: delta,
 								}) {
 									return
 								}
+							} else {
+								// Regular text part
+								// End any active reasoning block before starting text
+								if isActiveReasoning {
+									isActiveReasoning = false
+									if !yield(ai.StreamPart{
+										Type: ai.StreamPartTypeReasoningEnd,
+										ID:   currentReasoningBlockID,
+									}) {
+										return
+									}
+								}
+
+								// Start new text block if not already active
+								if !isActiveText {
+									isActiveText = true
+									currentTextBlockID = fmt.Sprintf("%d", blockCounter)
+									blockCounter++
+									if !yield(ai.StreamPart{
+										Type: ai.StreamPartTypeTextStart,
+										ID:   currentTextBlockID,
+									}) {
+										return
+									}
+								}
+
+								if !yield(ai.StreamPart{
+									Type:  ai.StreamPartTypeTextDelta,
+									ID:    currentTextBlockID,
+									Delta: delta,
+								}) {
+									return
+								}
+								currentContent += delta
 							}
-							if !yield(ai.StreamPart{
-								Type:  ai.StreamPartTypeTextDelta,
-								ID:    "0",
-								Delta: delta,
-							}) {
-								return
-							}
-							currentContent += delta
 						}
 					case part.FunctionCall != nil:
+						// End any active text or reasoning blocks
 						if isActiveText {
 							isActiveText = false
 							if !yield(ai.StreamPart{
 								Type: ai.StreamPartTypeTextEnd,
-								ID:   "0",
+								ID:   currentTextBlockID,
+							}) {
+								return
+							}
+						}
+						if isActiveReasoning {
+							isActiveReasoning = false
+							if !yield(ai.StreamPart{
+								Type: ai.StreamPartTypeReasoningEnd,
+								ID:   currentReasoningBlockID,
 							}) {
 								return
 							}
@@ -535,20 +600,35 @@ func (g *languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResp
 			if resp.UsageMetadata != nil {
 				usage = mapUsage(resp.UsageMetadata)
 			}
+
+			if len(resp.Candidates) > 0 && resp.Candidates[0].FinishReason != "" {
+				lastFinishReason = mapFinishReason(resp.Candidates[0].FinishReason)
+			}
 		}
 
+		// Close any open blocks before finishing
 		if isActiveText {
 			if !yield(ai.StreamPart{
 				Type: ai.StreamPartTypeTextEnd,
-				ID:   "0",
+				ID:   currentTextBlockID,
+			}) {
+				return
+			}
+		}
+		if isActiveReasoning {
+			if !yield(ai.StreamPart{
+				Type: ai.StreamPartTypeReasoningEnd,
+				ID:   currentReasoningBlockID,
 			}) {
 				return
 			}
 		}
 
-		finishReason := ai.FinishReasonStop
+		finishReason := lastFinishReason
 		if len(toolCalls) > 0 {
 			finishReason = ai.FinishReasonToolCalls
+		} else if finishReason == "" {
+			finishReason = ai.FinishReasonStop
 		}
 
 		yield(ai.StreamPart{
@@ -720,21 +800,27 @@ func mapResponse(response *genai.GenerateContentResponse, warnings []ai.CallWarn
 	for _, part := range candidate.Content.Parts {
 		switch {
 		case part.Text != "":
-			content = append(content, ai.TextContent{Text: part.Text})
+			if part.Thought {
+				content = append(content, ai.ReasoningContent{Text: part.Text})
+			} else {
+				content = append(content, ai.TextContent{Text: part.Text})
+			}
 		case part.FunctionCall != nil:
 			input, err := json.Marshal(part.FunctionCall.Args)
 			if err != nil {
 				return nil, err
 			}
+			toolCallID := cmp.Or(part.FunctionCall.ID, part.FunctionCall.Name, uuid.NewString())
 			content = append(content, ai.ToolCallContent{
-				ToolCallID:       part.FunctionCall.ID,
+				ToolCallID:       toolCallID,
 				ToolName:         part.FunctionCall.Name,
 				Input:            string(input),
 				ProviderExecuted: false,
 			})
 			hasToolCalls = true
 		default:
-			return nil, fmt.Errorf("not implemented part type")
+			// Silently skip unknown part types instead of erroring
+			// This allows for forward compatibility with new part types
 		}
 	}