Cargo.lock 🔗
@@ -99,6 +99,7 @@ dependencies = [
"paths",
"picker",
"postage",
+ "pretty_assertions",
"project",
"prompt_store",
"proto",
Bennet Bo Fenner created
This adds some unit tests to ensure that the `update(...)`/migration
path to the latest versions works correctly
Release Notes:
- N/A
Cargo.lock | 1
crates/agent/Cargo.toml | 1
crates/agent/src/thread.rs | 10
crates/agent/src/thread_store.rs | 198 ++++++++++++++++++++++++++++++++-
4 files changed, 197 insertions(+), 13 deletions(-)
@@ -99,6 +99,7 @@ dependencies = [
"paths",
"picker",
"postage",
+ "pretty_assertions",
"project",
"prompt_store",
"proto",
@@ -109,5 +109,6 @@ gpui = { workspace = true, "features" = ["test-support"] }
indoc.workspace = true
language = { workspace = true, "features" = ["test-support"] }
language_model = { workspace = true, "features" = ["test-support"] }
+pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
@@ -195,20 +195,20 @@ impl MessageSegment {
}
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
pub worktree_snapshots: Vec<WorktreeSnapshot>,
pub unsaved_buffer_paths: Vec<String>,
pub timestamp: DateTime<Utc>,
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
pub worktree_path: String,
pub git_state: Option<GitState>,
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
pub remote_url: Option<String>,
pub head_sha: Option<String>,
@@ -247,7 +247,7 @@ impl LastRestoreCheckpoint {
}
}
-#[derive(Clone, Debug, Default, Serialize, Deserialize)]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
#[default]
NotGenerated,
@@ -391,7 +391,7 @@ impl ThreadSummary {
}
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
/// Model used when last message exceeded context window
model_id: LanguageModelId,
@@ -603,7 +603,7 @@ pub struct SerializedThreadMetadata {
pub updated_at: DateTime<Utc>,
}
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
pub version: String,
pub summary: SharedString,
@@ -629,7 +629,7 @@ pub struct SerializedThread {
pub profile: Option<AgentProfileId>,
}
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
pub provider: String,
pub model: String,
@@ -690,11 +690,15 @@ impl SerializedThreadV0_1_0 {
messages.push(message);
}
- SerializedThread { messages, ..self.0 }
+ SerializedThread {
+ messages,
+ version: SerializedThread::VERSION.to_string(),
+ ..self.0
+ }
}
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
pub id: MessageId,
pub role: Role,
@@ -712,7 +716,7 @@ pub struct SerializedMessage {
pub is_hidden: bool,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
#[serde(rename = "text")]
@@ -730,14 +734,14 @@ pub enum SerializedMessageSegment {
},
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
pub id: LanguageModelToolUseId,
pub name: SharedString,
pub input: serde_json::Value,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
pub tool_use_id: LanguageModelToolUseId,
pub is_error: bool,
@@ -800,7 +804,7 @@ impl LegacySerializedMessage {
}
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
pub start: usize,
pub end: usize,
@@ -1057,3 +1061,181 @@ impl ThreadsDatabase {
})
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::thread::{DetailedSummaryState, MessageId};
+ use chrono::Utc;
+ use language_model::{Role, TokenUsage};
+ use pretty_assertions::assert_eq;
+
+ #[test]
+ fn test_legacy_serialized_thread_upgrade() {
+ let updated_at = Utc::now();
+ let legacy_thread = LegacySerializedThread {
+ summary: "Test conversation".into(),
+ updated_at,
+ messages: vec![LegacySerializedMessage {
+ id: MessageId(1),
+ role: Role::User,
+ text: "Hello, world!".to_string(),
+ tool_uses: vec![],
+ tool_results: vec![],
+ }],
+ initial_project_snapshot: None,
+ };
+
+ let upgraded = legacy_thread.upgrade();
+
+ assert_eq!(
+ upgraded,
+ SerializedThread {
+ summary: "Test conversation".into(),
+ updated_at,
+ messages: vec![SerializedMessage {
+ id: MessageId(1),
+ role: Role::User,
+ segments: vec![SerializedMessageSegment::Text {
+ text: "Hello, world!".to_string()
+ }],
+ tool_uses: vec![],
+ tool_results: vec![],
+ context: "".to_string(),
+ creases: vec![],
+ is_hidden: false
+ }],
+ version: SerializedThread::VERSION.to_string(),
+ initial_project_snapshot: None,
+ cumulative_token_usage: TokenUsage::default(),
+ request_token_usage: vec![],
+ detailed_summary_state: DetailedSummaryState::default(),
+ exceeded_window_error: None,
+ model: None,
+ completion_mode: None,
+ tool_use_limit_reached: false,
+ profile: None
+ }
+ )
+ }
+
+ #[test]
+ fn test_serialized_threadv0_1_0_upgrade() {
+ let updated_at = Utc::now();
+ let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
+ summary: "Test conversation".into(),
+ updated_at,
+ messages: vec![
+ SerializedMessage {
+ id: MessageId(1),
+ role: Role::User,
+ segments: vec![SerializedMessageSegment::Text {
+ text: "Use tool_1".to_string(),
+ }],
+ tool_uses: vec![],
+ tool_results: vec![],
+ context: "".to_string(),
+ creases: vec![],
+ is_hidden: false,
+ },
+ SerializedMessage {
+ id: MessageId(2),
+ role: Role::Assistant,
+ segments: vec![SerializedMessageSegment::Text {
+ text: "I want to use a tool".to_string(),
+ }],
+ tool_uses: vec![SerializedToolUse {
+ id: "abc".into(),
+ name: "tool_1".into(),
+ input: serde_json::Value::Null,
+ }],
+ tool_results: vec![],
+ context: "".to_string(),
+ creases: vec![],
+ is_hidden: false,
+ },
+ SerializedMessage {
+ id: MessageId(1),
+ role: Role::User,
+ segments: vec![SerializedMessageSegment::Text {
+ text: "Here is the tool result".to_string(),
+ }],
+ tool_uses: vec![],
+ tool_results: vec![SerializedToolResult {
+ tool_use_id: "abc".into(),
+ is_error: false,
+ content: LanguageModelToolResultContent::Text("abcdef".into()),
+ output: Some(serde_json::Value::Null),
+ }],
+ context: "".to_string(),
+ creases: vec![],
+ is_hidden: false,
+ },
+ ],
+ version: SerializedThreadV0_1_0::VERSION.to_string(),
+ initial_project_snapshot: None,
+ cumulative_token_usage: TokenUsage::default(),
+ request_token_usage: vec![],
+ detailed_summary_state: DetailedSummaryState::default(),
+ exceeded_window_error: None,
+ model: None,
+ completion_mode: None,
+ tool_use_limit_reached: false,
+ profile: None,
+ });
+ let upgraded = thread_v0_1_0.upgrade();
+
+ assert_eq!(
+ upgraded,
+ SerializedThread {
+ summary: "Test conversation".into(),
+ updated_at,
+ messages: vec![
+ SerializedMessage {
+ id: MessageId(1),
+ role: Role::User,
+ segments: vec![SerializedMessageSegment::Text {
+ text: "Use tool_1".to_string()
+ }],
+ tool_uses: vec![],
+ tool_results: vec![],
+ context: "".to_string(),
+ creases: vec![],
+ is_hidden: false
+ },
+ SerializedMessage {
+ id: MessageId(2),
+ role: Role::Assistant,
+ segments: vec![SerializedMessageSegment::Text {
+ text: "I want to use a tool".to_string(),
+ }],
+ tool_uses: vec![SerializedToolUse {
+ id: "abc".into(),
+ name: "tool_1".into(),
+ input: serde_json::Value::Null,
+ }],
+ tool_results: vec![SerializedToolResult {
+ tool_use_id: "abc".into(),
+ is_error: false,
+ content: LanguageModelToolResultContent::Text("abcdef".into()),
+ output: Some(serde_json::Value::Null),
+ }],
+ context: "".to_string(),
+ creases: vec![],
+ is_hidden: false,
+ },
+ ],
+ version: SerializedThread::VERSION.to_string(),
+ initial_project_snapshot: None,
+ cumulative_token_usage: TokenUsage::default(),
+ request_token_usage: vec![],
+ detailed_summary_state: DetailedSummaryState::default(),
+ exceeded_window_error: None,
+ model: None,
+ completion_mode: None,
+ tool_use_limit_reached: false,
+ profile: None
+ }
+ )
+ }
+}