agent: Add tests for thread serialization code (#32383)

Bennet Bo Fenner created

This adds some unit tests to ensure that the `update(...)`/migration
path to the latest versions works correctly

Release Notes:

- N/A

Change summary

Cargo.lock                       |   1 
crates/agent/Cargo.toml          |   1 
crates/agent/src/thread.rs       |  10 
crates/agent/src/thread_store.rs | 198 ++++++++++++++++++++++++++++++++-
4 files changed, 197 insertions(+), 13 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -99,6 +99,7 @@ dependencies = [
  "paths",
  "picker",
  "postage",
+ "pretty_assertions",
  "project",
  "prompt_store",
  "proto",

crates/agent/Cargo.toml 🔗

@@ -109,5 +109,6 @@ gpui = { workspace = true, "features" = ["test-support"] }
 indoc.workspace = true
 language = { workspace = true, "features" = ["test-support"] }
 language_model = { workspace = true, "features" = ["test-support"] }
+pretty_assertions.workspace = true
 project = { workspace = true, features = ["test-support"] }
 rand.workspace = true

crates/agent/src/thread.rs 🔗

@@ -195,20 +195,20 @@ impl MessageSegment {
     }
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct ProjectSnapshot {
     pub worktree_snapshots: Vec<WorktreeSnapshot>,
     pub unsaved_buffer_paths: Vec<String>,
     pub timestamp: DateTime<Utc>,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct WorktreeSnapshot {
     pub worktree_path: String,
     pub git_state: Option<GitState>,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct GitState {
     pub remote_url: Option<String>,
     pub head_sha: Option<String>,
@@ -247,7 +247,7 @@ impl LastRestoreCheckpoint {
     }
 }
 
-#[derive(Clone, Debug, Default, Serialize, Deserialize)]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 pub enum DetailedSummaryState {
     #[default]
     NotGenerated,
@@ -391,7 +391,7 @@ impl ThreadSummary {
     }
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct ExceededWindowError {
     /// Model used when last message exceeded context window
     model_id: LanguageModelId,

crates/agent/src/thread_store.rs 🔗

@@ -603,7 +603,7 @@ pub struct SerializedThreadMetadata {
     pub updated_at: DateTime<Utc>,
 }
 
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, PartialEq)]
 pub struct SerializedThread {
     pub version: String,
     pub summary: SharedString,
@@ -629,7 +629,7 @@ pub struct SerializedThread {
     pub profile: Option<AgentProfileId>,
 }
 
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, PartialEq)]
 pub struct SerializedLanguageModel {
     pub provider: String,
     pub model: String,
@@ -690,11 +690,15 @@ impl SerializedThreadV0_1_0 {
             messages.push(message);
         }
 
-        SerializedThread { messages, ..self.0 }
+        SerializedThread {
+            messages,
+            version: SerializedThread::VERSION.to_string(),
+            ..self.0
+        }
     }
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
 pub struct SerializedMessage {
     pub id: MessageId,
     pub role: Role,
@@ -712,7 +716,7 @@ pub struct SerializedMessage {
     pub is_hidden: bool,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
 #[serde(tag = "type")]
 pub enum SerializedMessageSegment {
     #[serde(rename = "text")]
@@ -730,14 +734,14 @@ pub enum SerializedMessageSegment {
     },
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
 pub struct SerializedToolUse {
     pub id: LanguageModelToolUseId,
     pub name: SharedString,
     pub input: serde_json::Value,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
 pub struct SerializedToolResult {
     pub tool_use_id: LanguageModelToolUseId,
     pub is_error: bool,
@@ -800,7 +804,7 @@ impl LegacySerializedMessage {
     }
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
 pub struct SerializedCrease {
     pub start: usize,
     pub end: usize,
@@ -1057,3 +1061,181 @@ impl ThreadsDatabase {
         })
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::thread::{DetailedSummaryState, MessageId};
+    use chrono::Utc;
+    use language_model::{Role, TokenUsage};
+    use pretty_assertions::assert_eq;
+
+    #[test]
+    fn test_legacy_serialized_thread_upgrade() {
+        let updated_at = Utc::now();
+        let legacy_thread = LegacySerializedThread {
+            summary: "Test conversation".into(),
+            updated_at,
+            messages: vec![LegacySerializedMessage {
+                id: MessageId(1),
+                role: Role::User,
+                text: "Hello, world!".to_string(),
+                tool_uses: vec![],
+                tool_results: vec![],
+            }],
+            initial_project_snapshot: None,
+        };
+
+        let upgraded = legacy_thread.upgrade();
+
+        assert_eq!(
+            upgraded,
+            SerializedThread {
+                summary: "Test conversation".into(),
+                updated_at,
+                messages: vec![SerializedMessage {
+                    id: MessageId(1),
+                    role: Role::User,
+                    segments: vec![SerializedMessageSegment::Text {
+                        text: "Hello, world!".to_string()
+                    }],
+                    tool_uses: vec![],
+                    tool_results: vec![],
+                    context: "".to_string(),
+                    creases: vec![],
+                    is_hidden: false
+                }],
+                version: SerializedThread::VERSION.to_string(),
+                initial_project_snapshot: None,
+                cumulative_token_usage: TokenUsage::default(),
+                request_token_usage: vec![],
+                detailed_summary_state: DetailedSummaryState::default(),
+                exceeded_window_error: None,
+                model: None,
+                completion_mode: None,
+                tool_use_limit_reached: false,
+                profile: None
+            }
+        )
+    }
+
+    #[test]
+    fn test_serialized_threadv0_1_0_upgrade() {
+        let updated_at = Utc::now();
+        let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
+            summary: "Test conversation".into(),
+            updated_at,
+            messages: vec![
+                SerializedMessage {
+                    id: MessageId(1),
+                    role: Role::User,
+                    segments: vec![SerializedMessageSegment::Text {
+                        text: "Use tool_1".to_string(),
+                    }],
+                    tool_uses: vec![],
+                    tool_results: vec![],
+                    context: "".to_string(),
+                    creases: vec![],
+                    is_hidden: false,
+                },
+                SerializedMessage {
+                    id: MessageId(2),
+                    role: Role::Assistant,
+                    segments: vec![SerializedMessageSegment::Text {
+                        text: "I want to use a tool".to_string(),
+                    }],
+                    tool_uses: vec![SerializedToolUse {
+                        id: "abc".into(),
+                        name: "tool_1".into(),
+                        input: serde_json::Value::Null,
+                    }],
+                    tool_results: vec![],
+                    context: "".to_string(),
+                    creases: vec![],
+                    is_hidden: false,
+                },
+                SerializedMessage {
+                    id: MessageId(1),
+                    role: Role::User,
+                    segments: vec![SerializedMessageSegment::Text {
+                        text: "Here is the tool result".to_string(),
+                    }],
+                    tool_uses: vec![],
+                    tool_results: vec![SerializedToolResult {
+                        tool_use_id: "abc".into(),
+                        is_error: false,
+                        content: LanguageModelToolResultContent::Text("abcdef".into()),
+                        output: Some(serde_json::Value::Null),
+                    }],
+                    context: "".to_string(),
+                    creases: vec![],
+                    is_hidden: false,
+                },
+            ],
+            version: SerializedThreadV0_1_0::VERSION.to_string(),
+            initial_project_snapshot: None,
+            cumulative_token_usage: TokenUsage::default(),
+            request_token_usage: vec![],
+            detailed_summary_state: DetailedSummaryState::default(),
+            exceeded_window_error: None,
+            model: None,
+            completion_mode: None,
+            tool_use_limit_reached: false,
+            profile: None,
+        });
+        let upgraded = thread_v0_1_0.upgrade();
+
+        assert_eq!(
+            upgraded,
+            SerializedThread {
+                summary: "Test conversation".into(),
+                updated_at,
+                messages: vec![
+                    SerializedMessage {
+                        id: MessageId(1),
+                        role: Role::User,
+                        segments: vec![SerializedMessageSegment::Text {
+                            text: "Use tool_1".to_string()
+                        }],
+                        tool_uses: vec![],
+                        tool_results: vec![],
+                        context: "".to_string(),
+                        creases: vec![],
+                        is_hidden: false
+                    },
+                    SerializedMessage {
+                        id: MessageId(2),
+                        role: Role::Assistant,
+                        segments: vec![SerializedMessageSegment::Text {
+                            text: "I want to use a tool".to_string(),
+                        }],
+                        tool_uses: vec![SerializedToolUse {
+                            id: "abc".into(),
+                            name: "tool_1".into(),
+                            input: serde_json::Value::Null,
+                        }],
+                        tool_results: vec![SerializedToolResult {
+                            tool_use_id: "abc".into(),
+                            is_error: false,
+                            content: LanguageModelToolResultContent::Text("abcdef".into()),
+                            output: Some(serde_json::Value::Null),
+                        }],
+                        context: "".to_string(),
+                        creases: vec![],
+                        is_hidden: false,
+                    },
+                ],
+                version: SerializedThread::VERSION.to_string(),
+                initial_project_snapshot: None,
+                cumulative_token_usage: TokenUsage::default(),
+                request_token_usage: vec![],
+                detailed_summary_state: DetailedSummaryState::default(),
+                exceeded_window_error: None,
+                model: None,
+                completion_mode: None,
+                tool_use_limit_reached: false,
+                profile: None
+            }
+        )
+    }
+}