1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent_client_protocol as acp;
4use agent_settings::AgentProfileId;
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use language_model::Speed;
12use parking_lot::Mutex;
13use serde::{Deserialize, Serialize};
14use sqlez::{
15 bindable::{Bind, Column},
16 connection::Connection,
17 statement::Statement,
18};
19use std::sync::Arc;
20use ui::{App, SharedString};
21use zed_env_vars::ZED_STATELESS;
22
23pub type DbMessage = crate::Message;
24pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
25pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct DbThreadMetadata {
29 pub id: acp::SessionId,
30 pub parent_session_id: Option<acp::SessionId>,
31 #[serde(alias = "summary")]
32 pub title: SharedString,
33 pub updated_at: DateTime<Utc>,
34}
35
36#[derive(Debug, Serialize, Deserialize)]
37pub struct DbThread {
38 pub title: SharedString,
39 pub messages: Vec<DbMessage>,
40 pub updated_at: DateTime<Utc>,
41 #[serde(default)]
42 pub detailed_summary: Option<SharedString>,
43 #[serde(default)]
44 pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
45 #[serde(default)]
46 pub cumulative_token_usage: language_model::TokenUsage,
47 #[serde(default)]
48 pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
49 #[serde(default)]
50 pub model: Option<DbLanguageModel>,
51 #[serde(default)]
52 pub profile: Option<AgentProfileId>,
53 #[serde(default)]
54 pub imported: bool,
55 #[serde(default)]
56 pub subagent_context: Option<crate::SubagentContext>,
57 #[serde(default)]
58 pub speed: Option<Speed>,
59 #[serde(default)]
60 pub thinking_enabled: bool,
61 #[serde(default)]
62 pub thinking_effort: Option<String>,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct SharedThread {
67 pub title: SharedString,
68 pub messages: Vec<DbMessage>,
69 pub updated_at: DateTime<Utc>,
70 #[serde(default)]
71 pub model: Option<DbLanguageModel>,
72 pub version: String,
73}
74
75impl SharedThread {
76 pub const VERSION: &'static str = "1.0.0";
77
78 pub fn from_db_thread(thread: &DbThread) -> Self {
79 Self {
80 title: thread.title.clone(),
81 messages: thread.messages.clone(),
82 updated_at: thread.updated_at,
83 model: thread.model.clone(),
84 version: Self::VERSION.to_string(),
85 }
86 }
87
88 pub fn to_db_thread(self) -> DbThread {
89 DbThread {
90 title: format!("🔗 {}", self.title).into(),
91 messages: self.messages,
92 updated_at: self.updated_at,
93 detailed_summary: None,
94 initial_project_snapshot: None,
95 cumulative_token_usage: Default::default(),
96 request_token_usage: Default::default(),
97 model: self.model,
98 profile: None,
99 imported: true,
100 subagent_context: None,
101 speed: None,
102 thinking_enabled: false,
103 thinking_effort: None,
104 }
105 }
106
107 pub fn to_bytes(&self) -> Result<Vec<u8>> {
108 const COMPRESSION_LEVEL: i32 = 3;
109 let json = serde_json::to_vec(self)?;
110 let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
111 Ok(compressed)
112 }
113
114 pub fn from_bytes(data: &[u8]) -> Result<Self> {
115 let decompressed = zstd::decode_all(data)?;
116 Ok(serde_json::from_slice(&decompressed)?)
117 }
118}
119
120impl DbThread {
121 pub const VERSION: &'static str = "0.3.0";
122
123 pub fn from_json(json: &[u8]) -> Result<Self> {
124 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
125 match saved_thread_json.get("version") {
126 Some(serde_json::Value::String(version)) => match version.as_str() {
127 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
128 _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
129 json,
130 )?),
131 },
132 _ => {
133 Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
134 }
135 }
136 }
137
138 fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
139 let mut messages = Vec::new();
140 let mut request_token_usage = HashMap::default();
141
142 let mut last_user_message_id = None;
143 for (ix, msg) in thread.messages.into_iter().enumerate() {
144 let message = match msg.role {
145 language_model::Role::User => {
146 let mut content = Vec::new();
147
148 // Convert segments to content
149 for segment in msg.segments {
150 match segment {
151 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
152 content.push(UserMessageContent::Text(text));
153 }
154 crate::legacy_thread::SerializedMessageSegment::Thinking {
155 text,
156 ..
157 } => {
158 // User messages don't have thinking segments, but handle gracefully
159 content.push(UserMessageContent::Text(text));
160 }
161 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
162 ..
163 } => {
164 // User messages don't have redacted thinking, skip.
165 }
166 }
167 }
168
169 // If no content was added, add context as text if available
170 if content.is_empty() && !msg.context.is_empty() {
171 content.push(UserMessageContent::Text(msg.context));
172 }
173
174 let id = UserMessageId::new();
175 last_user_message_id = Some(id.clone());
176
177 crate::Message::User(UserMessage {
178 // MessageId from old format can't be meaningfully converted, so generate a new one
179 id,
180 content,
181 })
182 }
183 language_model::Role::Assistant => {
184 let mut content = Vec::new();
185
186 // Convert segments to content
187 for segment in msg.segments {
188 match segment {
189 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
190 content.push(AgentMessageContent::Text(text));
191 }
192 crate::legacy_thread::SerializedMessageSegment::Thinking {
193 text,
194 signature,
195 } => {
196 content.push(AgentMessageContent::Thinking { text, signature });
197 }
198 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
199 data,
200 } => {
201 content.push(AgentMessageContent::RedactedThinking(data));
202 }
203 }
204 }
205
206 // Convert tool uses
207 let mut tool_names_by_id = HashMap::default();
208 for tool_use in msg.tool_uses {
209 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
210 content.push(AgentMessageContent::ToolUse(
211 language_model::LanguageModelToolUse {
212 id: tool_use.id,
213 name: tool_use.name.into(),
214 raw_input: serde_json::to_string(&tool_use.input)
215 .unwrap_or_default(),
216 input: tool_use.input,
217 is_input_complete: true,
218 thought_signature: None,
219 },
220 ));
221 }
222
223 // Convert tool results
224 let mut tool_results = IndexMap::default();
225 for tool_result in msg.tool_results {
226 let name = tool_names_by_id
227 .remove(&tool_result.tool_use_id)
228 .unwrap_or_else(|| SharedString::from("unknown"));
229 tool_results.insert(
230 tool_result.tool_use_id.clone(),
231 language_model::LanguageModelToolResult {
232 tool_use_id: tool_result.tool_use_id,
233 tool_name: name.into(),
234 is_error: tool_result.is_error,
235 content: tool_result.content,
236 output: tool_result.output,
237 },
238 );
239 }
240
241 if let Some(last_user_message_id) = &last_user_message_id
242 && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
243 {
244 request_token_usage.insert(last_user_message_id.clone(), token_usage);
245 }
246
247 crate::Message::Agent(AgentMessage {
248 content,
249 tool_results,
250 reasoning_details: None,
251 })
252 }
253 language_model::Role::System => {
254 // Skip system messages as they're not supported in the new format
255 continue;
256 }
257 };
258
259 messages.push(message);
260 }
261
262 Ok(Self {
263 title: thread.summary,
264 messages,
265 updated_at: thread.updated_at,
266 detailed_summary: match thread.detailed_summary_state {
267 crate::legacy_thread::DetailedSummaryState::NotGenerated
268 | crate::legacy_thread::DetailedSummaryState::Generating => None,
269 crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
270 },
271 initial_project_snapshot: thread.initial_project_snapshot,
272 cumulative_token_usage: thread.cumulative_token_usage,
273 request_token_usage,
274 model: thread.model,
275 profile: thread.profile,
276 imported: false,
277 subagent_context: None,
278 speed: None,
279 thinking_enabled: false,
280 thinking_effort: None,
281 })
282 }
283}
284
285#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
286pub enum DataType {
287 #[serde(rename = "json")]
288 Json,
289 #[serde(rename = "zstd")]
290 Zstd,
291}
292
293impl Bind for DataType {
294 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
295 let value = match self {
296 DataType::Json => "json",
297 DataType::Zstd => "zstd",
298 };
299 value.bind(statement, start_index)
300 }
301}
302
303impl Column for DataType {
304 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
305 let (value, next_index) = String::column(statement, start_index)?;
306 let data_type = match value.as_str() {
307 "json" => DataType::Json,
308 "zstd" => DataType::Zstd,
309 _ => anyhow::bail!("Unknown data type: {}", value),
310 };
311 Ok((data_type, next_index))
312 }
313}
314
315pub(crate) struct ThreadsDatabase {
316 executor: BackgroundExecutor,
317 connection: Arc<Mutex<Connection>>,
318}
319
320struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
321
322impl Global for GlobalThreadsDatabase {}
323
324impl ThreadsDatabase {
325 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
326 if cx.has_global::<GlobalThreadsDatabase>() {
327 return cx.global::<GlobalThreadsDatabase>().0.clone();
328 }
329 let executor = cx.background_executor().clone();
330 let task = executor
331 .spawn({
332 let executor = executor.clone();
333 async move {
334 match ThreadsDatabase::new(executor) {
335 Ok(db) => Ok(Arc::new(db)),
336 Err(err) => Err(Arc::new(err)),
337 }
338 }
339 })
340 .shared();
341
342 cx.set_global(GlobalThreadsDatabase(task.clone()));
343 task
344 }
345
346 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
347 let connection = if *ZED_STATELESS {
348 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
349 } else if cfg!(any(feature = "test-support", test)) {
350 // rust stores the name of the test on the current thread.
351 // We use this to automatically create a database that will
352 // be shared within the test (for the test_retrieve_old_thread)
353 // but not with concurrent tests.
354 let thread = std::thread::current();
355 let test_name = thread.name();
356 Connection::open_memory(Some(&format!(
357 "THREAD_FALLBACK_{}",
358 test_name.unwrap_or_default()
359 )))
360 } else {
361 let threads_dir = paths::data_dir().join("threads");
362 std::fs::create_dir_all(&threads_dir)?;
363 let sqlite_path = threads_dir.join("threads.db");
364 Connection::open_file(&sqlite_path.to_string_lossy())
365 };
366
367 connection.exec(indoc! {"
368 CREATE TABLE IF NOT EXISTS threads (
369 id TEXT PRIMARY KEY,
370 summary TEXT NOT NULL,
371 updated_at TEXT NOT NULL,
372 data_type TEXT NOT NULL,
373 data BLOB NOT NULL
374 )
375 "})?()
376 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
377
378 if let Ok(mut s) = connection.exec(indoc! {"
379 ALTER TABLE threads ADD COLUMN parent_id TEXT
380 "})
381 {
382 s().ok();
383 }
384
385 let db = Self {
386 executor,
387 connection: Arc::new(Mutex::new(connection)),
388 };
389
390 Ok(db)
391 }
392
393 fn save_thread_sync(
394 connection: &Arc<Mutex<Connection>>,
395 id: acp::SessionId,
396 thread: DbThread,
397 ) -> Result<()> {
398 const COMPRESSION_LEVEL: i32 = 3;
399
400 #[derive(Serialize)]
401 struct SerializedThread {
402 #[serde(flatten)]
403 thread: DbThread,
404 version: &'static str,
405 }
406
407 let title = thread.title.to_string();
408 let updated_at = thread.updated_at.to_rfc3339();
409 let parent_id = thread
410 .subagent_context
411 .as_ref()
412 .map(|ctx| ctx.parent_thread_id.0.clone());
413 let json_data = serde_json::to_string(&SerializedThread {
414 thread,
415 version: DbThread::VERSION,
416 })?;
417
418 let connection = connection.lock();
419
420 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
421 let data_type = DataType::Zstd;
422 let data = compressed;
423
424 let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, String, String, DataType, Vec<u8>)>(indoc! {"
425 INSERT OR REPLACE INTO threads (id, parent_id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?)
426 "})?;
427
428 insert((id.0, parent_id, title, updated_at, data_type, data))?;
429
430 Ok(())
431 }
432
433 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
434 let connection = self.connection.clone();
435
436 self.executor.spawn(async move {
437 let connection = connection.lock();
438
439 let mut select = connection
440 .select_bound::<(), (Arc<str>, Option<Arc<str>>, String, String)>(indoc! {"
441 SELECT id, parent_id, summary, updated_at FROM threads ORDER BY updated_at DESC
442 "})?;
443
444 let rows = select(())?;
445 let mut threads = Vec::new();
446
447 for (id, parent_id, summary, updated_at) in rows {
448 threads.push(DbThreadMetadata {
449 id: acp::SessionId::new(id),
450 parent_session_id: parent_id.map(acp::SessionId::new),
451 title: summary.into(),
452 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
453 });
454 }
455
456 Ok(threads)
457 })
458 }
459
460 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
461 let connection = self.connection.clone();
462
463 self.executor.spawn(async move {
464 let connection = connection.lock();
465 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
466 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
467 "})?;
468
469 let rows = select(id.0)?;
470 if let Some((data_type, data)) = rows.into_iter().next() {
471 let json_data = match data_type {
472 DataType::Zstd => {
473 let decompressed = zstd::decode_all(&data[..])?;
474 String::from_utf8(decompressed)?
475 }
476 DataType::Json => String::from_utf8(data)?,
477 };
478 let thread = DbThread::from_json(json_data.as_bytes())?;
479 Ok(Some(thread))
480 } else {
481 Ok(None)
482 }
483 })
484 }
485
486 pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
487 let connection = self.connection.clone();
488
489 self.executor
490 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
491 }
492
493 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
494 let connection = self.connection.clone();
495
496 self.executor.spawn(async move {
497 let connection = connection.lock();
498
499 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
500 DELETE FROM threads WHERE id = ?
501 "})?;
502
503 delete(id.0)?;
504
505 Ok(())
506 })
507 }
508
509 pub fn delete_threads(&self) -> Task<Result<()>> {
510 let connection = self.connection.clone();
511
512 self.executor.spawn(async move {
513 let connection = connection.lock();
514
515 let mut delete = connection.exec_bound::<()>(indoc! {"
516 DELETE FROM threads
517 "})?;
518
519 delete(())?;
520
521 Ok(())
522 })
523 }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// JSON for a thread saved before any of the optional `DbThread` fields existed.
    const LEGACY_THREAD_JSON: &str = r#"{
        "title": "Old Thread",
        "messages": [],
        "updated_at": "2024-01-01T00:00:00Z"
    }"#;

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// A minimal thread with the given title and timestamp and all else empty.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: SharedString::from(title.to_string()),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
        }
    }

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: jan_2024(1),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let encoded = original.to_bytes().expect("Failed to serialize");
        let decoded = SharedThread::from_bytes(&encoded).expect("Failed to deserialize");

        assert_eq!(decoded.title, original.title);
        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.updated_at, original.updated_at);
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Older payloads have no `imported` field; it must default to false.
        let thread: DbThread =
            serde_json::from_str(LEGACY_THREAD_JSON).expect("Failed to deserialize");

        assert!(
            !thread.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let db = ThreadsDatabase::new(cx.executor()).unwrap();
        let (older_id, newer_id) = (session_id("thread-a"), session_id("thread-b"));

        db.save_thread(older_id.clone(), make_thread("Thread A", jan_2024(1)))
            .await
            .unwrap();
        db.save_thread(newer_id.clone(), make_thread("Thread B", jan_2024(2)))
            .await
            .unwrap();

        let listed = db.list_threads().await.unwrap();
        assert_eq!(listed.len(), 2);
        assert_eq!(listed[0].id, newer_id);
        assert_eq!(listed[1].id, older_id);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let db = ThreadsDatabase::new(cx.executor()).unwrap();
        let id = session_id("thread-a");

        db.save_thread(id.clone(), make_thread("Thread A", jan_2024(1)))
            .await
            .unwrap();
        db.save_thread(id.clone(), make_thread("Thread B", jan_2024(2)))
            .await
            .unwrap();

        let listed = db.list_threads().await.unwrap();
        assert_eq!(listed.len(), 1);
        let only = &listed[0];
        assert_eq!(only.id, id);
        assert_eq!(only.title.as_ref(), "Thread B");
        assert_eq!(only.updated_at, jan_2024(2));
    }

    #[test]
    fn test_subagent_context_defaults_to_none() {
        let thread: DbThread =
            serde_json::from_str(LEGACY_THREAD_JSON).expect("Failed to deserialize");

        assert!(
            thread.subagent_context.is_none(),
            "Legacy threads without subagent_context should default to None"
        );
    }

    #[gpui::test]
    async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
        let db = ThreadsDatabase::new(cx.executor()).unwrap();
        let parent_id = session_id("parent-thread");
        let child_id = session_id("child-thread");

        let child = DbThread {
            subagent_context: Some(crate::SubagentContext {
                parent_thread_id: parent_id.clone(),
                depth: 2,
            }),
            ..make_thread("Subagent Thread", jan_2024(1))
        };

        db.save_thread(child_id.clone(), child).await.unwrap();

        let restored = db
            .load_thread(child_id)
            .await
            .unwrap()
            .expect("thread should exist")
            .subagent_context
            .expect("subagent_context should be restored");
        assert_eq!(restored.parent_thread_id, parent_id);
        assert_eq!(restored.depth, 2);
    }

    #[gpui::test]
    async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
        let db = ThreadsDatabase::new(cx.executor()).unwrap();
        let id = session_id("regular-thread");

        db.save_thread(id.clone(), make_thread("Regular Thread", jan_2024(1)))
            .await
            .unwrap();

        let restored = db
            .load_thread(id)
            .await
            .unwrap()
            .expect("thread should exist");

        assert!(
            restored.subagent_context.is_none(),
            "Regular threads should have no subagent_context"
        );
    }
}