@@ -124,6 +124,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
ai.WithTools(a.tools...),
)
+ sessionLock := sync.Mutex{}
currentSession, err := a.sessions.Get(ctx, call.SessionID)
if err != nil {
return nil, fmt.Errorf("failed to get session: %w", err)
@@ -138,7 +139,9 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
// Generate title if first message
if len(msgs) == 0 {
wg.Go(func() {
- a.generateTitle(ctx, currentSession, call.Prompt)
+ sessionLock.Lock()
+ a.generateTitle(ctx, ¤tSession, call.Prompt)
+ sessionLock.Unlock()
})
}
@@ -304,8 +307,15 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
case ai.FinishReasonToolCalls:
finishReason = message.FinishReasonToolUse
}
+ slog.Info("OnStepFinish", "reason", stepResult.FinishReason)
currentAssistant.AddFinish(finishReason, "", "")
a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage)
+ sessionLock.Lock()
+ _, sessionErr := a.sessions.Save(genCtx, currentSession)
+ sessionLock.Unlock()
+ if sessionErr != nil {
+ return sessionErr
+ }
return a.messages.Update(genCtx, *currentAssistant)
},
})
@@ -525,7 +535,7 @@ func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.S
return msgs, nil
}
-func (a *sessionAgent) generateTitle(ctx context.Context, session session.Session, prompt string) {
+func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
if prompt == "" {
return
}
@@ -559,8 +569,8 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session session.Sessio
}
session.Title = title
- a.updateSessionUsage(a.smallModel, &session, resp.TotalUsage)
- _, saveErr := a.sessions.Save(ctx, session)
+ a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
+ _, saveErr := a.sessions.Save(ctx, *session)
if saveErr != nil {
slog.Error("failed to save session title & usage", "error", saveErr)
return