@@ -9,14 +9,12 @@ import (
"strings"
"time"
- "github.com/google/generative-ai-go/genai"
"github.com/google/uuid"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/opencode-ai/opencode/internal/logging"
"github.com/opencode-ai/opencode/internal/message"
- "google.golang.org/api/iterator"
- "google.golang.org/api/option"
+ "google.golang.org/genai"
)
type geminiOptions struct {
@@ -39,7 +37,7 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
o(&geminiOpts)
}
- client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
+ client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
if err != nil {
logging.Error("Failed to create Gemini client", "error", err)
return nil
@@ -57,11 +55,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
for _, msg := range messages {
switch msg.Role {
case message.User:
- var parts []genai.Part
- parts = append(parts, genai.Text(msg.Content().String()))
+ var parts []*genai.Part
+ parts = append(parts, &genai.Part{Text: msg.Content().String()})
for _, binaryContent := range msg.BinaryContent() {
imageFormat := strings.Split(binaryContent.MIMEType, "/")
- parts = append(parts, genai.ImageData(imageFormat[1], binaryContent.Data))
+ parts = append(parts, &genai.Part{InlineData: &genai.Blob{
+ MIMEType: imageFormat[1],
+ Data: binaryContent.Data,
+ }})
}
history = append(history, &genai.Content{
Parts: parts,
@@ -70,19 +71,21 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
case message.Assistant:
content := &genai.Content{
Role: "model",
- Parts: []genai.Part{},
+ Parts: []*genai.Part{},
}
if msg.Content().String() != "" {
- content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
+ content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()})
}
if len(msg.ToolCalls()) > 0 {
for _, call := range msg.ToolCalls() {
args, _ := parseJsonToMap(call.Input)
- content.Parts = append(content.Parts, genai.FunctionCall{
- Name: call.Name,
- Args: args,
+ content.Parts = append(content.Parts, &genai.Part{
+ FunctionCall: &genai.FunctionCall{
+ Name: call.Name,
+ Args: args,
+ },
})
}
}
@@ -110,10 +113,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
}
history = append(history, &genai.Content{
- Parts: []genai.Part{genai.FunctionResponse{
- Name: toolCall.Name,
- Response: response,
- }},
+ Parts: []*genai.Part{
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ Name: toolCall.Name,
+ Response: response,
+ },
+ },
+ },
Role: "function",
})
}
@@ -157,18 +164,6 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
}
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
- model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
- model.SystemInstruction = &genai.Content{
- Parts: []genai.Part{
- genai.Text(g.providerOptions.systemMessage),
- },
- }
- // Convert tools
- if len(tools) > 0 {
- model.Tools = g.convertTools(tools)
- }
-
// Convert messages
geminiMessages := g.convertMessages(messages)
@@ -178,16 +173,26 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
logging.Debug("Prepared messages", "messages", string(jsonData))
}
+ history := geminiMessages[:len(geminiMessages)-1] // All but last message
+ lastMsg := geminiMessages[len(geminiMessages)-1]
+ chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
+ MaxOutputTokens: int32(g.providerOptions.maxTokens),
+ SystemInstruction: &genai.Content{
+ Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
+ },
+ Tools: g.convertTools(tools),
+ }, history)
+
attempts := 0
for {
attempts++
var toolCalls []message.ToolCall
- chat := model.StartChat()
- chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
-
- lastMsg := geminiMessages[len(geminiMessages)-1]
- resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
+ var lastMsgParts []genai.Part
+ for _, part := range lastMsg.Parts {
+ lastMsgParts = append(lastMsgParts, *part)
+ }
+ resp, err := chat.SendMessage(ctx, lastMsgParts...)
// If there is an error we are going to see if we can retry the call
if err != nil {
retry, after, retryErr := g.shouldRetry(attempts, err)
@@ -210,15 +215,15 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
- switch p := part.(type) {
- case genai.Text:
- content = string(p)
- case genai.FunctionCall:
+ switch {
+ case part.Text != "":
+ content = string(part.Text)
+ case part.FunctionCall != nil:
id := "call_" + uuid.New().String()
- args, _ := json.Marshal(p.Args)
+ args, _ := json.Marshal(part.FunctionCall.Args)
toolCalls = append(toolCalls, message.ToolCall{
ID: id,
- Name: p.Name,
+ Name: part.FunctionCall.Name,
Input: string(args),
Type: "function",
Finished: true,
@@ -244,18 +249,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
}
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
- model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
- model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
- model.SystemInstruction = &genai.Content{
- Parts: []genai.Part{
- genai.Text(g.providerOptions.systemMessage),
- },
- }
- // Convert tools
- if len(tools) > 0 {
- model.Tools = g.convertTools(tools)
- }
-
// Convert messages
geminiMessages := g.convertMessages(messages)
@@ -265,6 +258,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
logging.Debug("Prepared messages", "messages", string(jsonData))
}
+ history := geminiMessages[:len(geminiMessages)-1] // All but last message
+ lastMsg := geminiMessages[len(geminiMessages)-1]
+ chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
+ MaxOutputTokens: int32(g.providerOptions.maxTokens),
+ SystemInstruction: &genai.Content{
+ Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
+ },
+ Tools: g.convertTools(tools),
+ }, history)
+
attempts := 0
eventChan := make(chan ProviderEvent)
@@ -273,11 +276,6 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
for {
attempts++
- chat := model.StartChat()
- chat.History = geminiMessages[:len(geminiMessages)-1]
- lastMsg := geminiMessages[len(geminiMessages)-1]
-
- iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
currentContent := ""
toolCalls := []message.ToolCall{}
@@ -285,11 +283,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
eventChan <- ProviderEvent{Type: EventContentStart}
- for {
- resp, err := iter.Next()
- if err == iterator.Done {
- break
- }
+ var lastMsgParts []genai.Part
+
+ for _, part := range lastMsg.Parts {
+ lastMsgParts = append(lastMsgParts, *part)
+ }
+ for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
if err != nil {
retry, after, retryErr := g.shouldRetry(attempts, err)
if retryErr != nil {
@@ -318,9 +317,9 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
- switch p := part.(type) {
- case genai.Text:
- delta := string(p)
+ switch {
+ case part.Text != "":
+ delta := string(part.Text)
if delta != "" {
eventChan <- ProviderEvent{
Type: EventContentDelta,
@@ -328,12 +327,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
}
currentContent += delta
}
- case genai.FunctionCall:
+ case part.FunctionCall != nil:
id := "call_" + uuid.New().String()
- args, _ := json.Marshal(p.Args)
+ args, _ := json.Marshal(part.FunctionCall.Args)
newCall := message.ToolCall{
ID: id,
- Name: p.Name,
+ Name: part.FunctionCall.Name,
Input: string(args),
Type: "function",
Finished: true,
@@ -421,12 +420,12 @@ func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
- if funcCall, ok := part.(genai.FunctionCall); ok {
+ if part.FunctionCall != nil {
id := "call_" + uuid.New().String()
- args, _ := json.Marshal(funcCall.Args)
+ args, _ := json.Marshal(part.FunctionCall.Args)
toolCalls = append(toolCalls, message.ToolCall{
ID: id,
- Name: funcCall.Name,
+ Name: part.FunctionCall.Name,
Input: string(args),
Type: "function",
})