1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent_client_protocol as acp;
4use agent_settings::AgentProfileId;
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use sqlez::{
14 bindable::{Bind, Column},
15 connection::Connection,
16 statement::Statement,
17};
18use std::sync::Arc;
19use ui::{App, SharedString};
20use zed_env_vars::ZED_STATELESS;
21
/// Message type stored in [`DbThread::messages`] and [`SharedThread::messages`];
/// an alias for [`crate::Message`].
pub type DbMessage = crate::Message;
/// Alias for the legacy thread's detailed-summary state.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Alias for the legacy serialized language-model descriptor, reused as the
/// `model` field of both [`DbThread`] and [`SharedThread`].
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
25
/// Lightweight listing entry for a stored thread: its id, title and last update time.
///
/// `ThreadsDatabase::list_threads` builds these from the `id`, `summary` and
/// `updated_at` columns without decoding the thread body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    pub id: acp::SessionId,
    /// Thread title. The older `summary` key is also accepted when deserializing.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Time of the last update, in UTC.
    pub updated_at: DateTime<Utc>,
}
33
/// A thread in the current on-disk format, as saved and loaded by `ThreadsDatabase`.
///
/// Every field after `updated_at` is `#[serde(default)]`. JSON written before such a
/// field existed therefore still deserializes, with the missing value set to its `Default`.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Display title. When saved, it is also copied into the `summary` column.
    pub title: SharedString,
    pub messages: Vec<DbMessage>,
    /// Last update time. When saved, it is also written as RFC 3339 text to the
    /// `updated_at` column.
    pub updated_at: DateTime<Utc>,
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage keyed by the id of the user message it is attributed to.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    /// Set to `true` by `SharedThread::to_db_thread`. Defaults to `false` when the key
    /// is absent from the JSON.
    #[serde(default)]
    pub imported: bool,
}
54
/// The shareable subset of a [`DbThread`]: title, messages, timestamp and model,
/// plus a format `version`. `to_bytes`/`from_bytes` encode it as zstd-compressed JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    pub title: SharedString,
    pub messages: Vec<DbMessage>,
    pub updated_at: DateTime<Utc>,
    /// Deserializes as `None` when absent.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Payload format version. `from_db_thread` sets it to `SharedThread::VERSION`.
    /// NOTE(review): `from_bytes` does not validate this value.
    pub version: String,
}
64
65impl SharedThread {
66 pub const VERSION: &'static str = "1.0.0";
67
68 pub fn from_db_thread(thread: &DbThread) -> Self {
69 Self {
70 title: thread.title.clone(),
71 messages: thread.messages.clone(),
72 updated_at: thread.updated_at,
73 model: thread.model.clone(),
74 version: Self::VERSION.to_string(),
75 }
76 }
77
78 pub fn to_db_thread(self) -> DbThread {
79 DbThread {
80 title: format!("🔗 {}", self.title).into(),
81 messages: self.messages,
82 updated_at: self.updated_at,
83 detailed_summary: None,
84 initial_project_snapshot: None,
85 cumulative_token_usage: Default::default(),
86 request_token_usage: Default::default(),
87 model: self.model,
88 profile: None,
89 imported: true,
90 }
91 }
92
93 pub fn to_bytes(&self) -> Result<Vec<u8>> {
94 const COMPRESSION_LEVEL: i32 = 3;
95 let json = serde_json::to_vec(self)?;
96 let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
97 Ok(compressed)
98 }
99
100 pub fn from_bytes(data: &[u8]) -> Result<Self> {
101 let decompressed = zstd::decode_all(data)?;
102 Ok(serde_json::from_slice(&decompressed)?)
103 }
104}
105
106impl DbThread {
107 pub const VERSION: &'static str = "0.3.0";
108
109 pub fn from_json(json: &[u8]) -> Result<Self> {
110 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
111 match saved_thread_json.get("version") {
112 Some(serde_json::Value::String(version)) => match version.as_str() {
113 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
114 _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
115 json,
116 )?),
117 },
118 _ => {
119 Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
120 }
121 }
122 }
123
124 fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
125 let mut messages = Vec::new();
126 let mut request_token_usage = HashMap::default();
127
128 let mut last_user_message_id = None;
129 for (ix, msg) in thread.messages.into_iter().enumerate() {
130 let message = match msg.role {
131 language_model::Role::User => {
132 let mut content = Vec::new();
133
134 // Convert segments to content
135 for segment in msg.segments {
136 match segment {
137 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
138 content.push(UserMessageContent::Text(text));
139 }
140 crate::legacy_thread::SerializedMessageSegment::Thinking {
141 text,
142 ..
143 } => {
144 // User messages don't have thinking segments, but handle gracefully
145 content.push(UserMessageContent::Text(text));
146 }
147 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
148 ..
149 } => {
150 // User messages don't have redacted thinking, skip.
151 }
152 }
153 }
154
155 // If no content was added, add context as text if available
156 if content.is_empty() && !msg.context.is_empty() {
157 content.push(UserMessageContent::Text(msg.context));
158 }
159
160 let id = UserMessageId::new();
161 last_user_message_id = Some(id.clone());
162
163 crate::Message::User(UserMessage {
164 // MessageId from old format can't be meaningfully converted, so generate a new one
165 id,
166 content,
167 })
168 }
169 language_model::Role::Assistant => {
170 let mut content = Vec::new();
171
172 // Convert segments to content
173 for segment in msg.segments {
174 match segment {
175 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
176 content.push(AgentMessageContent::Text(text));
177 }
178 crate::legacy_thread::SerializedMessageSegment::Thinking {
179 text,
180 signature,
181 } => {
182 content.push(AgentMessageContent::Thinking { text, signature });
183 }
184 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
185 data,
186 } => {
187 content.push(AgentMessageContent::RedactedThinking(data));
188 }
189 }
190 }
191
192 // Convert tool uses
193 let mut tool_names_by_id = HashMap::default();
194 for tool_use in msg.tool_uses {
195 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
196 content.push(AgentMessageContent::ToolUse(
197 language_model::LanguageModelToolUse {
198 id: tool_use.id,
199 name: tool_use.name.into(),
200 raw_input: serde_json::to_string(&tool_use.input)
201 .unwrap_or_default(),
202 input: tool_use.input,
203 is_input_complete: true,
204 thought_signature: None,
205 },
206 ));
207 }
208
209 // Convert tool results
210 let mut tool_results = IndexMap::default();
211 for tool_result in msg.tool_results {
212 let name = tool_names_by_id
213 .remove(&tool_result.tool_use_id)
214 .unwrap_or_else(|| SharedString::from("unknown"));
215 tool_results.insert(
216 tool_result.tool_use_id.clone(),
217 language_model::LanguageModelToolResult {
218 tool_use_id: tool_result.tool_use_id,
219 tool_name: name.into(),
220 is_error: tool_result.is_error,
221 content: tool_result.content,
222 output: tool_result.output,
223 },
224 );
225 }
226
227 if let Some(last_user_message_id) = &last_user_message_id
228 && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
229 {
230 request_token_usage.insert(last_user_message_id.clone(), token_usage);
231 }
232
233 crate::Message::Agent(AgentMessage {
234 content,
235 tool_results,
236 reasoning_details: None,
237 })
238 }
239 language_model::Role::System => {
240 // Skip system messages as they're not supported in the new format
241 continue;
242 }
243 };
244
245 messages.push(message);
246 }
247
248 Ok(Self {
249 title: thread.summary,
250 messages,
251 updated_at: thread.updated_at,
252 detailed_summary: match thread.detailed_summary_state {
253 crate::legacy_thread::DetailedSummaryState::NotGenerated
254 | crate::legacy_thread::DetailedSummaryState::Generating => None,
255 crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
256 },
257 initial_project_snapshot: thread.initial_project_snapshot,
258 cumulative_token_usage: thread.cumulative_token_usage,
259 request_token_usage,
260 model: thread.model,
261 profile: thread.profile,
262 imported: false,
263 })
264 }
265}
266
267#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
268pub enum DataType {
269 #[serde(rename = "json")]
270 Json,
271 #[serde(rename = "zstd")]
272 Zstd,
273}
274
275impl Bind for DataType {
276 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
277 let value = match self {
278 DataType::Json => "json",
279 DataType::Zstd => "zstd",
280 };
281 value.bind(statement, start_index)
282 }
283}
284
285impl Column for DataType {
286 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
287 let (value, next_index) = String::column(statement, start_index)?;
288 let data_type = match value.as_str() {
289 "json" => DataType::Json,
290 "zstd" => DataType::Zstd,
291 _ => anyhow::bail!("Unknown data type: {}", value),
292 };
293 Ok((data_type, next_index))
294 }
295}
296
/// SQLite-backed store for agent threads.
///
/// Queries are spawned on `executor`. All access to the single `connection` goes
/// through its mutex.
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    connection: Arc<Mutex<Connection>>,
}
301
/// gpui global holding the shared task that opens the [`ThreadsDatabase`].
///
/// `ThreadsDatabase::connect` stores the task here on first use. Later calls clone and
/// await the same task instead of opening another connection.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
305
306impl ThreadsDatabase {
307 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
308 if cx.has_global::<GlobalThreadsDatabase>() {
309 return cx.global::<GlobalThreadsDatabase>().0.clone();
310 }
311 let executor = cx.background_executor().clone();
312 let task = executor
313 .spawn({
314 let executor = executor.clone();
315 async move {
316 match ThreadsDatabase::new(executor) {
317 Ok(db) => Ok(Arc::new(db)),
318 Err(err) => Err(Arc::new(err)),
319 }
320 }
321 })
322 .shared();
323
324 cx.set_global(GlobalThreadsDatabase(task.clone()));
325 task
326 }
327
328 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
329 let connection = if *ZED_STATELESS {
330 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
331 } else if cfg!(any(feature = "test-support", test)) {
332 // rust stores the name of the test on the current thread.
333 // We use this to automatically create a database that will
334 // be shared within the test (for the test_retrieve_old_thread)
335 // but not with concurrent tests.
336 let thread = std::thread::current();
337 let test_name = thread.name();
338 Connection::open_memory(Some(&format!(
339 "THREAD_FALLBACK_{}",
340 test_name.unwrap_or_default()
341 )))
342 } else {
343 let threads_dir = paths::data_dir().join("threads");
344 std::fs::create_dir_all(&threads_dir)?;
345 let sqlite_path = threads_dir.join("threads.db");
346 Connection::open_file(&sqlite_path.to_string_lossy())
347 };
348
349 connection.exec(indoc! {"
350 CREATE TABLE IF NOT EXISTS threads (
351 id TEXT PRIMARY KEY,
352 summary TEXT NOT NULL,
353 updated_at TEXT NOT NULL,
354 data_type TEXT NOT NULL,
355 data BLOB NOT NULL
356 )
357 "})?()
358 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
359
360 let db = Self {
361 executor,
362 connection: Arc::new(Mutex::new(connection)),
363 };
364
365 Ok(db)
366 }
367
368 fn save_thread_sync(
369 connection: &Arc<Mutex<Connection>>,
370 id: acp::SessionId,
371 thread: DbThread,
372 ) -> Result<()> {
373 const COMPRESSION_LEVEL: i32 = 3;
374
375 #[derive(Serialize)]
376 struct SerializedThread {
377 #[serde(flatten)]
378 thread: DbThread,
379 version: &'static str,
380 }
381
382 let title = thread.title.to_string();
383 let updated_at = thread.updated_at.to_rfc3339();
384 let json_data = serde_json::to_string(&SerializedThread {
385 thread,
386 version: DbThread::VERSION,
387 })?;
388
389 let connection = connection.lock();
390
391 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
392 let data_type = DataType::Zstd;
393 let data = compressed;
394
395 let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
396 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
397 "})?;
398
399 insert((id.0, title, updated_at, data_type, data))?;
400
401 Ok(())
402 }
403
404 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
405 let connection = self.connection.clone();
406
407 self.executor.spawn(async move {
408 let connection = connection.lock();
409
410 let mut select =
411 connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
412 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
413 "})?;
414
415 let rows = select(())?;
416 let mut threads = Vec::new();
417
418 for (id, summary, updated_at) in rows {
419 threads.push(DbThreadMetadata {
420 id: acp::SessionId::new(id),
421 title: summary.into(),
422 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
423 });
424 }
425
426 Ok(threads)
427 })
428 }
429
430 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
431 let connection = self.connection.clone();
432
433 self.executor.spawn(async move {
434 let connection = connection.lock();
435 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
436 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
437 "})?;
438
439 let rows = select(id.0)?;
440 if let Some((data_type, data)) = rows.into_iter().next() {
441 let json_data = match data_type {
442 DataType::Zstd => {
443 let decompressed = zstd::decode_all(&data[..])?;
444 String::from_utf8(decompressed)?
445 }
446 DataType::Json => String::from_utf8(data)?,
447 };
448 let thread = DbThread::from_json(json_data.as_bytes())?;
449 Ok(Some(thread))
450 } else {
451 Ok(None)
452 }
453 })
454 }
455
456 pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
457 let connection = self.connection.clone();
458
459 self.executor
460 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
461 }
462
463 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
464 let connection = self.connection.clone();
465
466 self.executor.spawn(async move {
467 let connection = connection.lock();
468
469 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
470 DELETE FROM threads WHERE id = ?
471 "})?;
472
473 delete(id.0)?;
474
475 Ok(())
476 })
477 }
478
479 pub fn delete_threads(&self) -> Task<Result<()>> {
480 let connection = self.connection.clone();
481
482 self.executor.spawn(async move {
483 let connection = connection.lock();
484
485 let mut delete = connection.exec_bound::<()>(indoc! {"
486 DELETE FROM threads
487 "})?;
488
489 delete(())?;
490
491 Ok(())
492 })
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            !db_thread.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    /// Builds a session id from a string literal.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Builds an empty, non-imported thread with the given title and timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
        }
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let newer_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(older_id.clone(), older_thread)
            .await
            .unwrap();
        database
            .save_thread(newer_id.clone(), newer_thread)
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let original_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let updated_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), original_thread)
            .await
            .unwrap();
        database
            .save_thread(thread_id.clone(), updated_thread)
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, thread_id);
        assert_eq!(entries[0].title.as_ref(), "Thread B");
        assert_eq!(
            entries[0].updated_at,
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
        );
    }

    /// Covers the zstd + versioned-JSON round trip through `load_thread`, plus `delete_thread`.
    #[gpui::test]
    async fn test_save_load_and_delete_thread(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let updated_at = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();

        database
            .save_thread(thread_id.clone(), make_thread("Thread A", updated_at))
            .await
            .unwrap();

        let loaded = database
            .load_thread(thread_id.clone())
            .await
            .unwrap()
            .expect("saved thread should be loadable");
        assert_eq!(loaded.title.as_ref(), "Thread A");
        assert_eq!(loaded.updated_at, updated_at);
        assert!(loaded.messages.is_empty());
        assert!(!loaded.imported);

        assert!(
            database
                .load_thread(session_id("missing"))
                .await
                .unwrap()
                .is_none()
        );

        database.delete_thread(thread_id.clone()).await.unwrap();
        assert!(database.load_thread(thread_id).await.unwrap().is_none());
        assert!(database.list_threads().await.unwrap().is_empty());
    }

    #[gpui::test]
    async fn test_delete_threads_removes_all(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        database
            .save_thread(
                session_id("thread-a"),
                make_thread("Thread A", Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap()),
            )
            .await
            .unwrap();
        database
            .save_thread(
                session_id("thread-b"),
                make_thread("Thread B", Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()),
            )
            .await
            .unwrap();

        database.delete_threads().await.unwrap();
        assert!(database.list_threads().await.unwrap().is_empty());
    }
}