@@ -54,19 +54,6 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
var history []*genai.Content
-
- // Add system message first
- history = append(history, &genai.Content{
- Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)},
- Role: "user",
- })
-
- // Add a system response to acknowledge the system message
- history = append(history, &genai.Content{
- Parts: []genai.Part{genai.Text("I'll help you with that.")},
- Role: "model",
- })
-
for _, msg := range messages {
switch msg.Role {
case message.User:
@@ -154,14 +141,11 @@ func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
}
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
- reasonStr := reason.String()
switch {
- case reasonStr == "STOP":
+ case reason == genai.FinishReasonStop:
return message.FinishReasonEndTurn
- case reasonStr == "MAX_TOKENS":
+ case reason == genai.FinishReasonMaxTokens:
return message.FinishReasonMaxTokens
- case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"):
- return message.FinishReasonToolUse
default:
return message.FinishReasonUnknown
}
@@ -170,7 +154,11 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
-
+ model.SystemInstruction = &genai.Content{
+ Parts: []genai.Part{
+ genai.Text(g.providerOptions.systemMessage),
+ },
+ }
// Convert tools
if len(tools) > 0 {
model.Tools = g.convertTools(tools)
@@ -188,19 +176,13 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
attempts := 0
for {
attempts++
+ var toolCalls []message.ToolCall
chat := model.StartChat()
chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
- var lastText string
- for _, part := range lastMsg.Parts {
- if text, ok := part.(genai.Text); ok {
- lastText = string(text)
- break
- }
- }
- resp, err := chat.SendMessage(ctx, genai.Text(lastText))
+ resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
// If there is an error we are going to see if we can retry the call
if err != nil {
retry, after, retryErr := g.shouldRetry(attempts, err)
@@ -220,7 +202,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
}
content := ""
- var toolCalls []message.ToolCall
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
@@ -231,20 +212,25 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
id := "call_" + uuid.New().String()
args, _ := json.Marshal(p.Args)
toolCalls = append(toolCalls, message.ToolCall{
- ID: id,
- Name: p.Name,
- Input: string(args),
- Type: "function",
+ ID: id,
+ Name: p.Name,
+ Input: string(args),
+ Type: "function",
+ Finished: true,
})
}
}
}
+ finishReason := g.finishReason(resp.Candidates[0].FinishReason)
+ if len(toolCalls) > 0 {
+ finishReason = message.FinishReasonToolUse
+ }
return &ProviderResponse{
Content: content,
ToolCalls: toolCalls,
Usage: g.usage(resp),
- FinishReason: g.finishReason(resp.Candidates[0].FinishReason),
+ FinishReason: finishReason,
}, nil
}
}
@@ -252,7 +238,11 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
-
+ model.SystemInstruction = &genai.Content{
+ Parts: []genai.Part{
+ genai.Text(g.providerOptions.systemMessage),
+ },
+ }
// Convert tools
if len(tools) > 0 {
model.Tools = g.convertTools(tools)
@@ -276,18 +266,10 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
for {
attempts++
chat := model.StartChat()
- chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
-
+ chat.History = geminiMessages[:len(geminiMessages)-1]
lastMsg := geminiMessages[len(geminiMessages)-1]
- var lastText string
- for _, part := range lastMsg.Parts {
- if text, ok := part.(genai.Text); ok {
- lastText = string(text)
- break
- }
- }
- iter := chat.SendMessageStream(ctx, genai.Text(lastText))
+ iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
currentContent := ""
toolCalls := []message.ToolCall{}
@@ -330,23 +312,23 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
for _, part := range resp.Candidates[0].Content.Parts {
switch p := part.(type) {
case genai.Text:
- newText := string(p)
- delta := newText[len(currentContent):]
+ delta := string(p)
if delta != "" {
eventChan <- ProviderEvent{
Type: EventContentDelta,
Content: delta,
}
- currentContent = newText
+ currentContent += delta
}
case genai.FunctionCall:
id := "call_" + uuid.New().String()
args, _ := json.Marshal(p.Args)
newCall := message.ToolCall{
- ID: id,
- Name: p.Name,
- Input: string(args),
- Type: "function",
+ ID: id,
+ Name: p.Name,
+ Input: string(args),
+ Type: "function",
+ Finished: true,
}
isNew := true
@@ -368,37 +350,22 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
eventChan <- ProviderEvent{Type: EventContentStop}
if finalResp != nil {
+ finishReason := g.finishReason(finalResp.Candidates[0].FinishReason)
+ if len(toolCalls) > 0 {
+ finishReason = message.FinishReasonToolUse
+ }
eventChan <- ProviderEvent{
Type: EventComplete,
Response: &ProviderResponse{
Content: currentContent,
ToolCalls: toolCalls,
Usage: g.usage(finalResp),
- FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason),
+ FinishReason: finishReason,
},
}
return
}
- // If we get here, we need to retry
- if attempts > maxRetries {
- eventChan <- ProviderEvent{
- Type: EventError,
- Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries),
- }
- return
- }
-
- // Wait before retrying
- select {
- case <-ctx.Done():
- if ctx.Err() != nil {
- eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
- }
- return
- case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond):
- continue
- }
}
}()