Fix Gemini provider: use native system instruction, send all message parts, and report tool-use finish reason

Kujtim Hoxha created

Change summary

internal/llm/provider/gemini.go | 109 ++++++++++++----------------------
1 file changed, 38 insertions(+), 71 deletions(-)

Detailed changes

internal/llm/provider/gemini.go 🔗

@@ -54,19 +54,6 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
 
 func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
 	var history []*genai.Content
-
-	// Add system message first
-	history = append(history, &genai.Content{
-		Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)},
-		Role:  "user",
-	})
-
-	// Add a system response to acknowledge the system message
-	history = append(history, &genai.Content{
-		Parts: []genai.Part{genai.Text("I'll help you with that.")},
-		Role:  "model",
-	})
-
 	for _, msg := range messages {
 		switch msg.Role {
 		case message.User:
@@ -154,14 +141,11 @@ func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
 }
 
 func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
-	reasonStr := reason.String()
 	switch {
-	case reasonStr == "STOP":
+	case reason == genai.FinishReasonStop:
 		return message.FinishReasonEndTurn
-	case reasonStr == "MAX_TOKENS":
+	case reason == genai.FinishReasonMaxTokens:
 		return message.FinishReasonMaxTokens
-	case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"):
-		return message.FinishReasonToolUse
 	default:
 		return message.FinishReasonUnknown
 	}
@@ -170,7 +154,11 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
 func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 	model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
 	model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
-
+	model.SystemInstruction = &genai.Content{
+		Parts: []genai.Part{
+			genai.Text(g.providerOptions.systemMessage),
+		},
+	}
 	// Convert tools
 	if len(tools) > 0 {
 		model.Tools = g.convertTools(tools)
@@ -188,19 +176,13 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
 	attempts := 0
 	for {
 		attempts++
+		var toolCalls []message.ToolCall
 		chat := model.StartChat()
 		chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
 
 		lastMsg := geminiMessages[len(geminiMessages)-1]
-		var lastText string
-		for _, part := range lastMsg.Parts {
-			if text, ok := part.(genai.Text); ok {
-				lastText = string(text)
-				break
-			}
-		}
 
-		resp, err := chat.SendMessage(ctx, genai.Text(lastText))
+		resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
 		// If there is an error we are going to see if we can retry the call
 		if err != nil {
 			retry, after, retryErr := g.shouldRetry(attempts, err)
@@ -220,7 +202,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
 		}
 
 		content := ""
-		var toolCalls []message.ToolCall
 
 		if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
 			for _, part := range resp.Candidates[0].Content.Parts {
@@ -231,20 +212,25 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
 					id := "call_" + uuid.New().String()
 					args, _ := json.Marshal(p.Args)
 					toolCalls = append(toolCalls, message.ToolCall{
-						ID:    id,
-						Name:  p.Name,
-						Input: string(args),
-						Type:  "function",
+						ID:       id,
+						Name:     p.Name,
+						Input:    string(args),
+						Type:     "function",
+						Finished: true,
 					})
 				}
 			}
 		}
+		finishReason := g.finishReason(resp.Candidates[0].FinishReason)
+		if len(toolCalls) > 0 {
+			finishReason = message.FinishReasonToolUse
+		}
 
 		return &ProviderResponse{
 			Content:      content,
 			ToolCalls:    toolCalls,
 			Usage:        g.usage(resp),
-			FinishReason: g.finishReason(resp.Candidates[0].FinishReason),
+			FinishReason: finishReason,
 		}, nil
 	}
 }
@@ -252,7 +238,11 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
 func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
 	model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
 	model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
-
+	model.SystemInstruction = &genai.Content{
+		Parts: []genai.Part{
+			genai.Text(g.providerOptions.systemMessage),
+		},
+	}
 	// Convert tools
 	if len(tools) > 0 {
 		model.Tools = g.convertTools(tools)
@@ -276,18 +266,10 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
 		for {
 			attempts++
 			chat := model.StartChat()
-			chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
-
+			chat.History = geminiMessages[:len(geminiMessages)-1]
 			lastMsg := geminiMessages[len(geminiMessages)-1]
-			var lastText string
-			for _, part := range lastMsg.Parts {
-				if text, ok := part.(genai.Text); ok {
-					lastText = string(text)
-					break
-				}
-			}
 
-			iter := chat.SendMessageStream(ctx, genai.Text(lastText))
+			iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
 
 			currentContent := ""
 			toolCalls := []message.ToolCall{}
@@ -330,23 +312,23 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
 					for _, part := range resp.Candidates[0].Content.Parts {
 						switch p := part.(type) {
 						case genai.Text:
-							newText := string(p)
-							delta := newText[len(currentContent):]
+							delta := string(p)
 							if delta != "" {
 								eventChan <- ProviderEvent{
 									Type:    EventContentDelta,
 									Content: delta,
 								}
-								currentContent = newText
+								currentContent += delta
 							}
 						case genai.FunctionCall:
 							id := "call_" + uuid.New().String()
 							args, _ := json.Marshal(p.Args)
 							newCall := message.ToolCall{
-								ID:    id,
-								Name:  p.Name,
-								Input: string(args),
-								Type:  "function",
+								ID:       id,
+								Name:     p.Name,
+								Input:    string(args),
+								Type:     "function",
+								Finished: true,
 							}
 
 							isNew := true
@@ -368,37 +350,22 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
 			eventChan <- ProviderEvent{Type: EventContentStop}
 
 			if finalResp != nil {
+				finishReason := g.finishReason(finalResp.Candidates[0].FinishReason)
+				if len(toolCalls) > 0 {
+					finishReason = message.FinishReasonToolUse
+				}
 				eventChan <- ProviderEvent{
 					Type: EventComplete,
 					Response: &ProviderResponse{
 						Content:      currentContent,
 						ToolCalls:    toolCalls,
 						Usage:        g.usage(finalResp),
-						FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason),
+						FinishReason: finishReason,
 					},
 				}
 				return
 			}
 
-			// If we get here, we need to retry
-			if attempts > maxRetries {
-				eventChan <- ProviderEvent{
-					Type:  EventError,
-					Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries),
-				}
-				return
-			}
-
-			// Wait before retrying
-			select {
-			case <-ctx.Done():
-				if ctx.Err() != nil {
-					eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
-				}
-				return
-			case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond):
-				continue
-			}
 		}
 	}()