1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent_client_protocol as acp;
4use agent_settings::AgentProfileId;
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use sqlez::{
14 bindable::{Bind, Column},
15 connection::Connection,
16 statement::Statement,
17};
18use std::sync::Arc;
19use ui::{App, SharedString};
20use zed_env_vars::ZED_STATELESS;
21
/// Message type persisted inside a [`DbThread`] (alias for `crate::Message`).
pub type DbMessage = crate::Message;
/// Legacy detailed-summary state, exposed under a database-oriented name.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Language-model record stored with a thread (alias for the legacy `SerializedLanguageModel`).
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
25
/// Lightweight per-thread metadata, as produced by `ThreadsDatabase::list_threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Identifier of the thread (the `id` column).
    pub id: acp::SessionId,
    /// Parent thread id (the `parent_id` column). It is only populated when the thread was
    /// saved with a subagent context; otherwise `None`.
    pub parent_session_id: Option<acp::SessionId>,
    /// Display title. When deserializing, the older `summary` key is accepted as well.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Last-updated timestamp (stored as RFC 3339 text in the `updated_at` column).
    pub updated_at: DateTime<Utc>,
}
34
/// A complete agent thread as persisted in the threads database.
///
/// Fields marked `#[serde(default)]` may be absent from data written by older versions
/// and fall back to their `Default` value.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Display title; also written to the `summary` column when saved.
    pub title: SharedString,
    /// Conversation messages, in order.
    pub messages: Vec<DbMessage>,
    /// Last-updated timestamp; also written to the `updated_at` column and used to order listings.
    pub updated_at: DateTime<Utc>,
    /// Generated detailed summary of the thread, if one exists.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Project snapshot associated with the thread, if any.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    /// Token usage accumulated across the whole thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Per-request token usage, keyed by user message id.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Language model the thread was using, if recorded.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Agent profile the thread was using, if recorded.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// `true` when the thread was created from a [`SharedThread`] rather than locally.
    #[serde(default)]
    pub imported: bool,
    /// Present for subagent threads; its parent id is mirrored into the `parent_id` column.
    #[serde(default)]
    pub subagent_context: Option<crate::SubagentContext>,
}
57
/// Portable form of a thread used for sharing.
///
/// Only the title, messages, timestamp and model are carried. Local-only state such as
/// token usage, project snapshot and profile is not included.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    /// Title of the shared thread.
    pub title: SharedString,
    /// Conversation messages, in order.
    pub messages: Vec<DbMessage>,
    /// Last-updated timestamp of the source thread.
    pub updated_at: DateTime<Utc>,
    /// Language model used by the source thread, if recorded.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Format version of the payload. It is required when deserializing (no serde default).
    pub version: String,
}
67
68impl SharedThread {
69 pub const VERSION: &'static str = "1.0.0";
70
71 pub fn from_db_thread(thread: &DbThread) -> Self {
72 Self {
73 title: thread.title.clone(),
74 messages: thread.messages.clone(),
75 updated_at: thread.updated_at,
76 model: thread.model.clone(),
77 version: Self::VERSION.to_string(),
78 }
79 }
80
81 pub fn to_db_thread(self) -> DbThread {
82 DbThread {
83 title: format!("🔗 {}", self.title).into(),
84 messages: self.messages,
85 updated_at: self.updated_at,
86 detailed_summary: None,
87 initial_project_snapshot: None,
88 cumulative_token_usage: Default::default(),
89 request_token_usage: Default::default(),
90 model: self.model,
91 profile: None,
92 imported: true,
93 subagent_context: None,
94 }
95 }
96
97 pub fn to_bytes(&self) -> Result<Vec<u8>> {
98 const COMPRESSION_LEVEL: i32 = 3;
99 let json = serde_json::to_vec(self)?;
100 let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
101 Ok(compressed)
102 }
103
104 pub fn from_bytes(data: &[u8]) -> Result<Self> {
105 let decompressed = zstd::decode_all(data)?;
106 Ok(serde_json::from_slice(&decompressed)?)
107 }
108}
109
110impl DbThread {
111 pub const VERSION: &'static str = "0.3.0";
112
113 pub fn from_json(json: &[u8]) -> Result<Self> {
114 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
115 match saved_thread_json.get("version") {
116 Some(serde_json::Value::String(version)) => match version.as_str() {
117 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
118 _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
119 json,
120 )?),
121 },
122 _ => {
123 Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
124 }
125 }
126 }
127
128 fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
129 let mut messages = Vec::new();
130 let mut request_token_usage = HashMap::default();
131
132 let mut last_user_message_id = None;
133 for (ix, msg) in thread.messages.into_iter().enumerate() {
134 let message = match msg.role {
135 language_model::Role::User => {
136 let mut content = Vec::new();
137
138 // Convert segments to content
139 for segment in msg.segments {
140 match segment {
141 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
142 content.push(UserMessageContent::Text(text));
143 }
144 crate::legacy_thread::SerializedMessageSegment::Thinking {
145 text,
146 ..
147 } => {
148 // User messages don't have thinking segments, but handle gracefully
149 content.push(UserMessageContent::Text(text));
150 }
151 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
152 ..
153 } => {
154 // User messages don't have redacted thinking, skip.
155 }
156 }
157 }
158
159 // If no content was added, add context as text if available
160 if content.is_empty() && !msg.context.is_empty() {
161 content.push(UserMessageContent::Text(msg.context));
162 }
163
164 let id = UserMessageId::new();
165 last_user_message_id = Some(id.clone());
166
167 crate::Message::User(UserMessage {
168 // MessageId from old format can't be meaningfully converted, so generate a new one
169 id,
170 content,
171 })
172 }
173 language_model::Role::Assistant => {
174 let mut content = Vec::new();
175
176 // Convert segments to content
177 for segment in msg.segments {
178 match segment {
179 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
180 content.push(AgentMessageContent::Text(text));
181 }
182 crate::legacy_thread::SerializedMessageSegment::Thinking {
183 text,
184 signature,
185 } => {
186 content.push(AgentMessageContent::Thinking { text, signature });
187 }
188 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
189 data,
190 } => {
191 content.push(AgentMessageContent::RedactedThinking(data));
192 }
193 }
194 }
195
196 // Convert tool uses
197 let mut tool_names_by_id = HashMap::default();
198 for tool_use in msg.tool_uses {
199 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
200 content.push(AgentMessageContent::ToolUse(
201 language_model::LanguageModelToolUse {
202 id: tool_use.id,
203 name: tool_use.name.into(),
204 raw_input: serde_json::to_string(&tool_use.input)
205 .unwrap_or_default(),
206 input: tool_use.input,
207 is_input_complete: true,
208 thought_signature: None,
209 },
210 ));
211 }
212
213 // Convert tool results
214 let mut tool_results = IndexMap::default();
215 for tool_result in msg.tool_results {
216 let name = tool_names_by_id
217 .remove(&tool_result.tool_use_id)
218 .unwrap_or_else(|| SharedString::from("unknown"));
219 tool_results.insert(
220 tool_result.tool_use_id.clone(),
221 language_model::LanguageModelToolResult {
222 tool_use_id: tool_result.tool_use_id,
223 tool_name: name.into(),
224 is_error: tool_result.is_error,
225 content: tool_result.content,
226 output: tool_result.output,
227 },
228 );
229 }
230
231 if let Some(last_user_message_id) = &last_user_message_id
232 && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
233 {
234 request_token_usage.insert(last_user_message_id.clone(), token_usage);
235 }
236
237 crate::Message::Agent(AgentMessage {
238 content,
239 tool_results,
240 reasoning_details: None,
241 })
242 }
243 language_model::Role::System => {
244 // Skip system messages as they're not supported in the new format
245 continue;
246 }
247 };
248
249 messages.push(message);
250 }
251
252 Ok(Self {
253 title: thread.summary,
254 messages,
255 updated_at: thread.updated_at,
256 detailed_summary: match thread.detailed_summary_state {
257 crate::legacy_thread::DetailedSummaryState::NotGenerated
258 | crate::legacy_thread::DetailedSummaryState::Generating => None,
259 crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
260 },
261 initial_project_snapshot: thread.initial_project_snapshot,
262 cumulative_token_usage: thread.cumulative_token_usage,
263 request_token_usage,
264 model: thread.model,
265 profile: thread.profile,
266 imported: false,
267 subagent_context: None,
268 })
269 }
270}
271
272#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
273pub enum DataType {
274 #[serde(rename = "json")]
275 Json,
276 #[serde(rename = "zstd")]
277 Zstd,
278}
279
280impl Bind for DataType {
281 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
282 let value = match self {
283 DataType::Json => "json",
284 DataType::Zstd => "zstd",
285 };
286 value.bind(statement, start_index)
287 }
288}
289
290impl Column for DataType {
291 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
292 let (value, next_index) = String::column(statement, start_index)?;
293 let data_type = match value.as_str() {
294 "json" => DataType::Json,
295 "zstd" => DataType::Zstd,
296 _ => anyhow::bail!("Unknown data type: {}", value),
297 };
298 Ok((data_type, next_index))
299 }
300}
301
/// SQLite-backed store for agent threads.
///
/// Operations are spawned on `executor` and share a single mutex-guarded `connection`.
pub(crate) struct ThreadsDatabase {
    /// Executor that database operations are spawned on.
    executor: BackgroundExecutor,
    /// The one connection to the database; operations lock it while issuing queries.
    connection: Arc<Mutex<Connection>>,
}
306
/// gpui global holding the shared task created by `ThreadsDatabase::connect`, so the
/// database is opened at most once per app. The error is wrapped in `Arc` so the
/// `Shared` task's output can be cloned to every awaiting caller.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
310
311impl ThreadsDatabase {
312 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
313 if cx.has_global::<GlobalThreadsDatabase>() {
314 return cx.global::<GlobalThreadsDatabase>().0.clone();
315 }
316 let executor = cx.background_executor().clone();
317 let task = executor
318 .spawn({
319 let executor = executor.clone();
320 async move {
321 match ThreadsDatabase::new(executor) {
322 Ok(db) => Ok(Arc::new(db)),
323 Err(err) => Err(Arc::new(err)),
324 }
325 }
326 })
327 .shared();
328
329 cx.set_global(GlobalThreadsDatabase(task.clone()));
330 task
331 }
332
333 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
334 let connection = if *ZED_STATELESS {
335 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
336 } else if cfg!(any(feature = "test-support", test)) {
337 // rust stores the name of the test on the current thread.
338 // We use this to automatically create a database that will
339 // be shared within the test (for the test_retrieve_old_thread)
340 // but not with concurrent tests.
341 let thread = std::thread::current();
342 let test_name = thread.name();
343 Connection::open_memory(Some(&format!(
344 "THREAD_FALLBACK_{}",
345 test_name.unwrap_or_default()
346 )))
347 } else {
348 let threads_dir = paths::data_dir().join("threads");
349 std::fs::create_dir_all(&threads_dir)?;
350 let sqlite_path = threads_dir.join("threads.db");
351 Connection::open_file(&sqlite_path.to_string_lossy())
352 };
353
354 connection.exec(indoc! {"
355 CREATE TABLE IF NOT EXISTS threads (
356 id TEXT PRIMARY KEY,
357 summary TEXT NOT NULL,
358 updated_at TEXT NOT NULL,
359 data_type TEXT NOT NULL,
360 data BLOB NOT NULL
361 )
362 "})?()
363 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
364
365 if let Ok(mut s) = connection.exec(indoc! {"
366 ALTER TABLE threads ADD COLUMN parent_id TEXT
367 "})
368 {
369 s().ok();
370 }
371
372 let db = Self {
373 executor,
374 connection: Arc::new(Mutex::new(connection)),
375 };
376
377 Ok(db)
378 }
379
380 fn save_thread_sync(
381 connection: &Arc<Mutex<Connection>>,
382 id: acp::SessionId,
383 thread: DbThread,
384 ) -> Result<()> {
385 const COMPRESSION_LEVEL: i32 = 3;
386
387 #[derive(Serialize)]
388 struct SerializedThread {
389 #[serde(flatten)]
390 thread: DbThread,
391 version: &'static str,
392 }
393
394 let title = thread.title.to_string();
395 let updated_at = thread.updated_at.to_rfc3339();
396 let parent_id = thread
397 .subagent_context
398 .as_ref()
399 .map(|ctx| ctx.parent_thread_id.0.clone());
400 let json_data = serde_json::to_string(&SerializedThread {
401 thread,
402 version: DbThread::VERSION,
403 })?;
404
405 let connection = connection.lock();
406
407 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
408 let data_type = DataType::Zstd;
409 let data = compressed;
410
411 let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, String, String, DataType, Vec<u8>)>(indoc! {"
412 INSERT OR REPLACE INTO threads (id, parent_id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?)
413 "})?;
414
415 insert((id.0, parent_id, title, updated_at, data_type, data))?;
416
417 Ok(())
418 }
419
420 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
421 let connection = self.connection.clone();
422
423 self.executor.spawn(async move {
424 let connection = connection.lock();
425
426 let mut select = connection
427 .select_bound::<(), (Arc<str>, Option<Arc<str>>, String, String)>(indoc! {"
428 SELECT id, parent_id, summary, updated_at FROM threads ORDER BY updated_at DESC
429 "})?;
430
431 let rows = select(())?;
432 let mut threads = Vec::new();
433
434 for (id, parent_id, summary, updated_at) in rows {
435 threads.push(DbThreadMetadata {
436 id: acp::SessionId::new(id),
437 parent_session_id: parent_id.map(acp::SessionId::new),
438 title: summary.into(),
439 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
440 });
441 }
442
443 Ok(threads)
444 })
445 }
446
447 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
448 let connection = self.connection.clone();
449
450 self.executor.spawn(async move {
451 let connection = connection.lock();
452 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
453 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
454 "})?;
455
456 let rows = select(id.0)?;
457 if let Some((data_type, data)) = rows.into_iter().next() {
458 let json_data = match data_type {
459 DataType::Zstd => {
460 let decompressed = zstd::decode_all(&data[..])?;
461 String::from_utf8(decompressed)?
462 }
463 DataType::Json => String::from_utf8(data)?,
464 };
465 let thread = DbThread::from_json(json_data.as_bytes())?;
466 Ok(Some(thread))
467 } else {
468 Ok(None)
469 }
470 })
471 }
472
473 pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
474 let connection = self.connection.clone();
475
476 self.executor
477 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
478 }
479
480 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
481 let connection = self.connection.clone();
482
483 self.executor.spawn(async move {
484 let connection = connection.lock();
485
486 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
487 DELETE FROM threads WHERE id = ?
488 "})?;
489
490 delete(id.0)?;
491
492 Ok(())
493 })
494 }
495
496 pub fn delete_threads(&self) -> Task<Result<()>> {
497 let connection = self.connection.clone();
498
499 self.executor.spawn(async move {
500 let connection = connection.lock();
501
502 let mut delete = connection.exec_bound::<()>(indoc! {"
503 DELETE FROM threads
504 "})?;
505
506 delete(())?;
507
508 Ok(())
509 })
510 }
511}
512
// Tests for thread (de)serialization and the SQLite-backed `ThreadsDatabase`.
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    // A `SharedThread` survives `to_bytes` -> `from_bytes` with title, version and
    // timestamp intact.
    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            !db_thread.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    // Builds an `acp::SessionId` from a string literal.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    // Builds an empty, non-imported, non-subagent thread with the given title and timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
        }
    }

    // `list_threads` returns the most recently updated thread first.
    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let newer_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(older_id.clone(), older_thread)
            .await
            .unwrap();
        database
            .save_thread(newer_id.clone(), newer_thread)
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    // Saving twice under the same id replaces the row (and its listed metadata)
    // rather than adding a second one.
    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let original_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let updated_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), original_thread)
            .await
            .unwrap();
        database
            .save_thread(thread_id.clone(), updated_thread)
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, thread_id);
        assert_eq!(entries[0].title.as_ref(), "Thread B");
        assert_eq!(
            entries[0].updated_at,
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
        );
    }

    // JSON written before `subagent_context` existed deserializes with it unset.
    #[test]
    fn test_subagent_context_defaults_to_none() {
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            db_thread.subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    // A thread saved with a subagent context loads back with the same parent id and depth.
    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let mut child_thread = make_thread(
            "Subagent Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        child_thread.subagent_context = Some(crate::SubagentContext {
            parent_thread_id: parent_id.clone(),
            depth: 2,
        });

        database
            .save_thread(child_id.clone(), child_thread)
            .await
            .unwrap();

        let loaded = database
            .load_thread(child_id)
            .await
            .unwrap()
            .expect("thread should exist");

        let context = loaded
            .subagent_context
            .expect("subagent_context should be restored");
        assert_eq!(context.parent_thread_id, parent_id);
        assert_eq!(context.depth, 2);
    }

    // A thread saved without a subagent context loads back without one.
    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("regular-thread");
        let thread = make_thread(
            "Regular Thread",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), thread)
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id)
            .await
            .unwrap()
            .expect("thread should exist");

        assert!(
            loaded.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }
}