handle errors correctly in the agent tool

Kujtim Hoxha created

Change summary

internal/llm/agent/agent-tool.go | 21 +++++++++++----------
1 file changed, 11 insertions(+), 10 deletions(-)

Detailed changes

internal/llm/agent/agent-tool.go 🔗

@@ -50,44 +50,45 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
 
 	sessionID, messageID := tools.GetContextValues(ctx)
 	if sessionID == "" || messageID == "" {
-		return tools.NewTextErrorResponse("session ID and message ID are required"), nil
+		return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required")
 	}
 
 	agent, err := NewTaskAgent(b.lspClients)
 	if err != nil {
-		return tools.NewTextErrorResponse(fmt.Sprintf("error creating agent: %s", err)), nil
+		return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err)
 	}
 
 	session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session")
 	if err != nil {
-		return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil
+		return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
 	}
 
 	err = agent.Generate(ctx, session.ID, params.Prompt)
 	if err != nil {
-		return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil
+		return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err)
 	}
 
 	messages, err := b.messages.List(ctx, session.ID)
 	if err != nil {
-		return tools.NewTextErrorResponse(fmt.Sprintf("error listing messages: %s", err)), nil
+		return tools.ToolResponse{}, fmt.Errorf("error listing messages: %s", err)
 	}
+
 	if len(messages) == 0 {
-		return tools.NewTextErrorResponse("no messages found"), nil
+		return tools.NewTextErrorResponse("no response"), nil
 	}
 
 	response := messages[len(messages)-1]
 	if response.Role != message.Assistant {
-		return tools.NewTextErrorResponse("no assistant message found"), nil
+		return tools.NewTextErrorResponse("no response"), nil
 	}
 
 	updatedSession, err := b.sessions.Get(ctx, session.ID)
 	if err != nil {
-		return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
+		return tools.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
 	}
 	parentSession, err := b.sessions.Get(ctx, sessionID)
 	if err != nil {
-		return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
+		return tools.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
 	}
 
 	parentSession.Cost += updatedSession.Cost
@@ -96,7 +97,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
 
 	_, err = b.sessions.Save(ctx, parentSession)
 	if err != nil {
-		return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
+		return tools.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
 	}
 	return tools.NewTextResponse(response.Content().String()), nil
 }