@@ -3,6 +3,7 @@ package agent
import (
"context"
_ "embed"
+ "encoding/json"
"errors"
"fmt"
"log/slog"
@@ -177,7 +178,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
// Before each step create the new assistant message
PrepareStep: func(callContext context.Context, options ai.PrepareStepFunctionOptions) (_ context.Context, prepared ai.PrepareStepResult, err error) {
var assistantMsg message.Message
- assistantMsg, err = a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
+ assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{},
Model: a.largeModel.ModelCfg.Model,
@@ -187,7 +188,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
return callContext, prepared, err
}
- callContext = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
+ callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
currentAssistant = &assistantMsg
@@ -200,7 +201,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
queuedCalls, _ := a.messageQueue.Get(call.SessionID)
a.messageQueue.Del(call.SessionID)
for _, queued := range queuedCalls {
- userMessage, createErr := a.createUserMessage(genCtx, queued)
+ userMessage, createErr := a.createUserMessage(callContext, queued)
if createErr != nil {
return callContext, prepared, createErr
}
@@ -291,12 +292,15 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
IsError: isError,
Metadata: result.ClientMetadata,
}
- a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
+ _, err := a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
Role: message.Tool,
Parts: []message.ContentPart{
toolResult,
},
})
+ if err != nil {
+ return err
+ }
return a.messages.Update(genCtx, *currentAssistant)
},
OnStepFinish: func(stepResult ai.StepResult) error {
@@ -328,16 +332,29 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
}
toolCalls := currentAssistant.ToolCalls()
toolResults := currentAssistant.ToolResults()
+ // INFO: we use the parent context here because the genCtx has been cancelled
+ msgs, createErr := a.messages.List(ctx, currentSession.ID)
+ if createErr != nil {
+ return nil, createErr
+ }
for _, tc := range toolCalls {
if !tc.Finished {
tc.Finished = true
tc.Input = "{}"
}
currentAssistant.AddToolCall(tc)
+
found := false
- for _, tr := range toolResults {
- if tr.ToolCallID == tc.ID {
- found = true
+ for _, msg := range msgs {
+ if msg.Role == message.Tool {
+ for _, tr := range toolResults {
+ if tr.ToolCallID == tc.ID {
+ found = true
+ break
+ }
+ }
+ }
+ if found {
break
}
}
@@ -348,12 +365,21 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
} else if isPermissionErr {
content = "Permission denied"
}
- currentAssistant.AddToolResult(message.ToolResult{
+ toolResult := message.ToolResult{
ToolCallID: tc.ID,
Name: tc.Name,
Content: content,
IsError: true,
+ }
+ _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
+ Role: message.Tool,
+ Parts: []message.ContentPart{
+ toolResult,
+ },
})
+ if createErr != nil {
+ return nil, createErr
+ }
}
}
if isCancelErr {
@@ -363,13 +389,11 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
} else {
currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
}
- // INFO: we use the parent context here because the genCtx might have been cancelled
+ // INFO: we use the parent context here because the genCtx has been cancelled
updateErr := a.messages.Update(ctx, *currentAssistant)
if updateErr != nil {
return nil, updateErr
}
- }
- if err != nil {
return nil, err
}
wg.Wait()
@@ -541,22 +565,32 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Sessi
return
}
+ var maxOutput int64 = 40
+ if a.smallModel.CatwalkCfg.CanReason {
+ maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
+ }
+
agent := ai.NewAgent(a.smallModel.Model,
- ai.WithSystemPrompt(string(titlePrompt)),
- ai.WithMaxOutputTokens(40),
+ ai.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
+ ai.WithMaxOutputTokens(maxOutput),
)
resp, err := agent.Stream(ctx, ai.AgentStreamCall{
- Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", prompt),
+ Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
})
if err != nil {
slog.Error("error generating title", "err", err)
return
}
+ data, _ := json.Marshal(resp)
+
+ slog.Info("Title Response")
+ slog.Info(string(data))
title := resp.Response.Content.Text()
title = strings.ReplaceAll(title, "\n", " ")
+ slog.Info(title)
// remove thinking tags if present
if idx := strings.Index(title, "</think>"); idx > 0 {
@@ -2,6 +2,7 @@ package chat
import (
"context"
+ "errors"
"fmt"
"time"
@@ -751,6 +752,11 @@ func (p *chatPage) sendMessage(text string, attachments []message.Attachment) te
cmds = append(cmds, func() tea.Msg {
_, err := p.app.AgentCoordinator.Run(context.Background(), session.ID, text, attachments...)
if err != nil {
+ isCancelErr := errors.Is(err, context.Canceled)
+ isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
+ if isCancelErr || isPermissionErr {
+ return nil
+ }
return util.InfoMsg{
Type: util.InfoTypeError,
Msg: err.Error(),