@@ -1,15 +1,12 @@
package server
import (
+ "fmt"
"testing"
"shelley.exe.dev/db"
)
-func boolPtr(b bool) *bool {
- return &b
-}
-
func TestAgentWorking(t *testing.T) {
tests := []struct {
name string
@@ -24,14 +21,14 @@ func TestAgentWorking(t *testing.T) {
{
name: "agent with end_of_turn true",
messages: []APIMessage{
- {Type: string(db.MessageTypeAgent), EndOfTurn: boolPtr(true)},
+ {Type: string(db.MessageTypeAgent), EndOfTurn: truePtr},
},
want: false,
},
{
name: "agent with end_of_turn false",
messages: []APIMessage{
- {Type: string(db.MessageTypeAgent), EndOfTurn: boolPtr(false)},
+ {Type: string(db.MessageTypeAgent), EndOfTurn: falsePtr},
},
want: true,
},
@@ -52,7 +49,7 @@ func TestAgentWorking(t *testing.T) {
{
name: "agent end_of_turn then tool message means working",
messages: []APIMessage{
- {Type: string(db.MessageTypeAgent), EndOfTurn: boolPtr(true)},
+ {Type: string(db.MessageTypeAgent), EndOfTurn: truePtr},
{Type: string(db.MessageTypeTool)},
},
want: true,
@@ -60,7 +57,7 @@ func TestAgentWorking(t *testing.T) {
{
name: "gitinfo after agent end_of_turn should NOT indicate working",
messages: []APIMessage{
- {Type: string(db.MessageTypeAgent), EndOfTurn: boolPtr(true)},
+ {Type: string(db.MessageTypeAgent), EndOfTurn: truePtr},
{Type: string(db.MessageTypeGitInfo)},
},
want: false,
@@ -68,7 +65,7 @@ func TestAgentWorking(t *testing.T) {
{
name: "multiple gitinfo after agent end_of_turn should NOT indicate working",
messages: []APIMessage{
- {Type: string(db.MessageTypeAgent), EndOfTurn: boolPtr(true)},
+ {Type: string(db.MessageTypeAgent), EndOfTurn: truePtr},
{Type: string(db.MessageTypeGitInfo)},
{Type: string(db.MessageTypeGitInfo)},
},
@@ -77,7 +74,7 @@ func TestAgentWorking(t *testing.T) {
{
name: "gitinfo after agent not end_of_turn should indicate working",
messages: []APIMessage{
- {Type: string(db.MessageTypeAgent), EndOfTurn: boolPtr(false)},
+ {Type: string(db.MessageTypeAgent), EndOfTurn: falsePtr},
{Type: string(db.MessageTypeGitInfo)},
},
want: true,
@@ -95,8 +92,12 @@ func TestAgentWorking(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := agentWorking(tt.messages)
- if got != tt.want {
- t.Errorf("agentWorking() = %v, want %v", got, tt.want)
+ if got == nil || *got != tt.want {
+ gotVal := "nil"
+ if got != nil {
+ gotVal = fmt.Sprintf("%v", *got)
+ }
+ t.Errorf("agentWorking() = %v, want %v", gotVal, tt.want)
}
})
}
@@ -557,7 +557,7 @@ func (cm *ConversationManager) notifyGitStateChange(ctx context.Context, msg *ge
streamData := StreamResponse{
Messages: apiMessages,
Conversation: conversation,
- AgentWorking: false, // Gitinfo is recorded at end of turn, agent is done
+ AgentWorking: falsePtr, // Gitinfo is recorded at end of turn, agent is done
}
cm.subpub.Publish(msg.SequenceID, streamData)
}
@@ -45,7 +45,7 @@ type APIMessage struct {
type StreamResponse struct {
Messages []APIMessage `json:"messages"`
Conversation generated.Conversation `json:"conversation"`
- AgentWorking bool `json:"agent_working"`
+ AgentWorking *bool `json:"agent_working,omitempty"`
ContextWindowSize uint64 `json:"context_window_size,omitempty"`
// ConversationListUpdate is set when another conversation in the list changed
ConversationListUpdate *ConversationListUpdate `json:"conversation_list_update,omitempty"`
@@ -149,9 +149,16 @@ func calculateContextWindowSize(messages []APIMessage) uint64 {
return 0
}
-func agentWorking(messages []APIMessage) bool {
+var (
+ truePtr = ptr(true)
+ falsePtr = ptr(false)
+)
+
+func ptr[T any](v T) *T { return &v }
+
+func agentWorking(messages []APIMessage) *bool {
if len(messages) == 0 {
- return false
+ return falsePtr
}
// Find the last non-gitinfo message (gitinfo messages are passive notifications)
@@ -160,20 +167,23 @@ func agentWorking(messages []APIMessage) bool {
lastIdx--
}
if lastIdx < 0 {
- return false
+ return falsePtr
}
last := messages[lastIdx]
// If the last message is an error, agent is not working
if last.Type == string(db.MessageTypeError) {
- return false
+ return falsePtr
}
if last.Type == string(db.MessageTypeAgent) {
if last.EndOfTurn == nil {
- return true
+ return truePtr
}
- return !*last.EndOfTurn
+ if *last.EndOfTurn {
+ return falsePtr
+ }
+ return truePtr
}
for i := lastIdx; i >= 0; i-- {
@@ -181,18 +191,12 @@ func agentWorking(messages []APIMessage) bool {
if msg.Type != string(db.MessageTypeAgent) {
continue
}
- if msg.EndOfTurn == nil {
- return true
- }
- if !*msg.EndOfTurn {
- return true
- }
// Agent ended turn, but newer non-agent messages exist, so agent is working again.
- return true
+ return truePtr
}
// No agent message found yet but conversation has activity, assume agent is working.
- return true
+ return truePtr
}
// isEndOfTurn checks if a database message represents end of turn
@@ -679,10 +683,14 @@ func (s *Server) notifySubscribersNewMessage(ctx context.Context, conversationID
apiMessages := toAPIMessages([]generated.Message{*newMsg})
// Publish only the new message
+ agentWorking := falsePtr
+ if !isEndOfTurn(newMsg) {
+ agentWorking = truePtr
+ }
streamData := StreamResponse{
Messages: apiMessages,
Conversation: conversation,
- AgentWorking: !isEndOfTurn(newMsg),
+ AgentWorking: agentWorking,
// ContextWindowSize: 0 for messages without usage data (user/tool messages).
// With omitempty, 0 is omitted from JSON, so the UI keeps its cached value.
// Only agent messages have usage data, so context window updates when they arrive.