1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent_client_protocol as acp;
4use agent_settings::{AgentProfileId, CompletionMode};
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use sqlez::{
14 bindable::{Bind, Column},
15 connection::Connection,
16 statement::Statement,
17};
18use std::sync::Arc;
19use ui::{App, SharedString};
20use zed_env_vars::ZED_STATELESS;
21
22pub type DbMessage = crate::Message;
23pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
24pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct DbThreadMetadata {
28 pub id: acp::SessionId,
29 #[serde(alias = "summary")]
30 pub title: SharedString,
31 pub updated_at: DateTime<Utc>,
32}
33
34#[derive(Debug, Serialize, Deserialize)]
35pub struct DbThread {
36 pub title: SharedString,
37 pub messages: Vec<DbMessage>,
38 pub updated_at: DateTime<Utc>,
39 #[serde(default)]
40 pub detailed_summary: Option<SharedString>,
41 #[serde(default)]
42 pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
43 #[serde(default)]
44 pub cumulative_token_usage: language_model::TokenUsage,
45 #[serde(default)]
46 pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
47 #[serde(default)]
48 pub model: Option<DbLanguageModel>,
49 #[serde(default)]
50 pub completion_mode: Option<CompletionMode>,
51 #[serde(default)]
52 pub profile: Option<AgentProfileId>,
53 #[serde(default)]
54 pub imported: bool,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct SharedThread {
59 pub title: SharedString,
60 pub messages: Vec<DbMessage>,
61 pub updated_at: DateTime<Utc>,
62 #[serde(default)]
63 pub model: Option<DbLanguageModel>,
64 #[serde(default)]
65 pub completion_mode: Option<CompletionMode>,
66 pub version: String,
67}
68
69impl SharedThread {
70 pub const VERSION: &'static str = "1.0.0";
71
72 pub fn from_db_thread(thread: &DbThread) -> Self {
73 Self {
74 title: thread.title.clone(),
75 messages: thread.messages.clone(),
76 updated_at: thread.updated_at,
77 model: thread.model.clone(),
78 completion_mode: thread.completion_mode,
79 version: Self::VERSION.to_string(),
80 }
81 }
82
83 pub fn to_db_thread(self) -> DbThread {
84 DbThread {
85 title: format!("🔗 {}", self.title).into(),
86 messages: self.messages,
87 updated_at: self.updated_at,
88 detailed_summary: None,
89 initial_project_snapshot: None,
90 cumulative_token_usage: Default::default(),
91 request_token_usage: Default::default(),
92 model: self.model,
93 completion_mode: self.completion_mode,
94 profile: None,
95 imported: true,
96 }
97 }
98
99 pub fn to_bytes(&self) -> Result<Vec<u8>> {
100 const COMPRESSION_LEVEL: i32 = 3;
101 let json = serde_json::to_vec(self)?;
102 let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
103 Ok(compressed)
104 }
105
106 pub fn from_bytes(data: &[u8]) -> Result<Self> {
107 let decompressed = zstd::decode_all(data)?;
108 Ok(serde_json::from_slice(&decompressed)?)
109 }
110}
111
112impl DbThread {
113 pub const VERSION: &'static str = "0.3.0";
114
115 pub fn from_json(json: &[u8]) -> Result<Self> {
116 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
117 match saved_thread_json.get("version") {
118 Some(serde_json::Value::String(version)) => match version.as_str() {
119 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
120 _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
121 json,
122 )?),
123 },
124 _ => {
125 Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
126 }
127 }
128 }
129
130 fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
131 let mut messages = Vec::new();
132 let mut request_token_usage = HashMap::default();
133
134 let mut last_user_message_id = None;
135 for (ix, msg) in thread.messages.into_iter().enumerate() {
136 let message = match msg.role {
137 language_model::Role::User => {
138 let mut content = Vec::new();
139
140 // Convert segments to content
141 for segment in msg.segments {
142 match segment {
143 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
144 content.push(UserMessageContent::Text(text));
145 }
146 crate::legacy_thread::SerializedMessageSegment::Thinking {
147 text,
148 ..
149 } => {
150 // User messages don't have thinking segments, but handle gracefully
151 content.push(UserMessageContent::Text(text));
152 }
153 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
154 ..
155 } => {
156 // User messages don't have redacted thinking, skip.
157 }
158 }
159 }
160
161 // If no content was added, add context as text if available
162 if content.is_empty() && !msg.context.is_empty() {
163 content.push(UserMessageContent::Text(msg.context));
164 }
165
166 let id = UserMessageId::new();
167 last_user_message_id = Some(id.clone());
168
169 crate::Message::User(UserMessage {
170 // MessageId from old format can't be meaningfully converted, so generate a new one
171 id,
172 content,
173 })
174 }
175 language_model::Role::Assistant => {
176 let mut content = Vec::new();
177
178 // Convert segments to content
179 for segment in msg.segments {
180 match segment {
181 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
182 content.push(AgentMessageContent::Text(text));
183 }
184 crate::legacy_thread::SerializedMessageSegment::Thinking {
185 text,
186 signature,
187 } => {
188 content.push(AgentMessageContent::Thinking { text, signature });
189 }
190 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
191 data,
192 } => {
193 content.push(AgentMessageContent::RedactedThinking(data));
194 }
195 }
196 }
197
198 // Convert tool uses
199 let mut tool_names_by_id = HashMap::default();
200 for tool_use in msg.tool_uses {
201 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
202 content.push(AgentMessageContent::ToolUse(
203 language_model::LanguageModelToolUse {
204 id: tool_use.id,
205 name: tool_use.name.into(),
206 raw_input: serde_json::to_string(&tool_use.input)
207 .unwrap_or_default(),
208 input: tool_use.input,
209 is_input_complete: true,
210 thought_signature: None,
211 },
212 ));
213 }
214
215 // Convert tool results
216 let mut tool_results = IndexMap::default();
217 for tool_result in msg.tool_results {
218 let name = tool_names_by_id
219 .remove(&tool_result.tool_use_id)
220 .unwrap_or_else(|| SharedString::from("unknown"));
221 tool_results.insert(
222 tool_result.tool_use_id.clone(),
223 language_model::LanguageModelToolResult {
224 tool_use_id: tool_result.tool_use_id,
225 tool_name: name.into(),
226 is_error: tool_result.is_error,
227 content: tool_result.content,
228 output: tool_result.output,
229 },
230 );
231 }
232
233 if let Some(last_user_message_id) = &last_user_message_id
234 && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
235 {
236 request_token_usage.insert(last_user_message_id.clone(), token_usage);
237 }
238
239 crate::Message::Agent(AgentMessage {
240 content,
241 tool_results,
242 reasoning_details: None,
243 })
244 }
245 language_model::Role::System => {
246 // Skip system messages as they're not supported in the new format
247 continue;
248 }
249 };
250
251 messages.push(message);
252 }
253
254 Ok(Self {
255 title: thread.summary,
256 messages,
257 updated_at: thread.updated_at,
258 detailed_summary: match thread.detailed_summary_state {
259 crate::legacy_thread::DetailedSummaryState::NotGenerated
260 | crate::legacy_thread::DetailedSummaryState::Generating => None,
261 crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
262 },
263 initial_project_snapshot: thread.initial_project_snapshot,
264 cumulative_token_usage: thread.cumulative_token_usage,
265 request_token_usage,
266 model: thread.model,
267 completion_mode: thread.completion_mode,
268 profile: thread.profile,
269 imported: false,
270 })
271 }
272}
273
274#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
275pub enum DataType {
276 #[serde(rename = "json")]
277 Json,
278 #[serde(rename = "zstd")]
279 Zstd,
280}
281
282impl Bind for DataType {
283 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
284 let value = match self {
285 DataType::Json => "json",
286 DataType::Zstd => "zstd",
287 };
288 value.bind(statement, start_index)
289 }
290}
291
292impl Column for DataType {
293 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
294 let (value, next_index) = String::column(statement, start_index)?;
295 let data_type = match value.as_str() {
296 "json" => DataType::Json,
297 "zstd" => DataType::Zstd,
298 _ => anyhow::bail!("Unknown data type: {}", value),
299 };
300 Ok((data_type, next_index))
301 }
302}
303
304pub(crate) struct ThreadsDatabase {
305 executor: BackgroundExecutor,
306 connection: Arc<Mutex<Connection>>,
307}
308
309struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
310
311impl Global for GlobalThreadsDatabase {}
312
313impl ThreadsDatabase {
314 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
315 if cx.has_global::<GlobalThreadsDatabase>() {
316 return cx.global::<GlobalThreadsDatabase>().0.clone();
317 }
318 let executor = cx.background_executor().clone();
319 let task = executor
320 .spawn({
321 let executor = executor.clone();
322 async move {
323 match ThreadsDatabase::new(executor) {
324 Ok(db) => Ok(Arc::new(db)),
325 Err(err) => Err(Arc::new(err)),
326 }
327 }
328 })
329 .shared();
330
331 cx.set_global(GlobalThreadsDatabase(task.clone()));
332 task
333 }
334
335 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
336 let connection = if *ZED_STATELESS {
337 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
338 } else if cfg!(any(feature = "test-support", test)) {
339 // rust stores the name of the test on the current thread.
340 // We use this to automatically create a database that will
341 // be shared within the test (for the test_retrieve_old_thread)
342 // but not with concurrent tests.
343 let thread = std::thread::current();
344 let test_name = thread.name();
345 Connection::open_memory(Some(&format!(
346 "THREAD_FALLBACK_{}",
347 test_name.unwrap_or_default()
348 )))
349 } else {
350 let threads_dir = paths::data_dir().join("threads");
351 std::fs::create_dir_all(&threads_dir)?;
352 let sqlite_path = threads_dir.join("threads.db");
353 Connection::open_file(&sqlite_path.to_string_lossy())
354 };
355
356 connection.exec(indoc! {"
357 CREATE TABLE IF NOT EXISTS threads (
358 id TEXT PRIMARY KEY,
359 summary TEXT NOT NULL,
360 updated_at TEXT NOT NULL,
361 data_type TEXT NOT NULL,
362 data BLOB NOT NULL
363 )
364 "})?()
365 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
366
367 let db = Self {
368 executor,
369 connection: Arc::new(Mutex::new(connection)),
370 };
371
372 Ok(db)
373 }
374
375 fn save_thread_sync(
376 connection: &Arc<Mutex<Connection>>,
377 id: acp::SessionId,
378 thread: DbThread,
379 ) -> Result<()> {
380 const COMPRESSION_LEVEL: i32 = 3;
381
382 #[derive(Serialize)]
383 struct SerializedThread {
384 #[serde(flatten)]
385 thread: DbThread,
386 version: &'static str,
387 }
388
389 let title = thread.title.to_string();
390 let updated_at = thread.updated_at.to_rfc3339();
391 let json_data = serde_json::to_string(&SerializedThread {
392 thread,
393 version: DbThread::VERSION,
394 })?;
395
396 let connection = connection.lock();
397
398 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
399 let data_type = DataType::Zstd;
400 let data = compressed;
401
402 let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
403 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
404 "})?;
405
406 insert((id.0, title, updated_at, data_type, data))?;
407
408 Ok(())
409 }
410
411 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
412 let connection = self.connection.clone();
413
414 self.executor.spawn(async move {
415 let connection = connection.lock();
416
417 let mut select =
418 connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
419 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
420 "})?;
421
422 let rows = select(())?;
423 let mut threads = Vec::new();
424
425 for (id, summary, updated_at) in rows {
426 threads.push(DbThreadMetadata {
427 id: acp::SessionId::new(id),
428 title: summary.into(),
429 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
430 });
431 }
432
433 Ok(threads)
434 })
435 }
436
437 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
438 let connection = self.connection.clone();
439
440 self.executor.spawn(async move {
441 let connection = connection.lock();
442 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
443 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
444 "})?;
445
446 let rows = select(id.0)?;
447 if let Some((data_type, data)) = rows.into_iter().next() {
448 let json_data = match data_type {
449 DataType::Zstd => {
450 let decompressed = zstd::decode_all(&data[..])?;
451 String::from_utf8(decompressed)?
452 }
453 DataType::Json => String::from_utf8(data)?,
454 };
455 let thread = DbThread::from_json(json_data.as_bytes())?;
456 Ok(Some(thread))
457 } else {
458 Ok(None)
459 }
460 })
461 }
462
463 pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
464 let connection = self.connection.clone();
465
466 self.executor
467 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
468 }
469
470 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
471 let connection = self.connection.clone();
472
473 self.executor.spawn(async move {
474 let connection = connection.lock();
475
476 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
477 DELETE FROM threads WHERE id = ?
478 "})?;
479
480 delete(id.0)?;
481
482 Ok(())
483 })
484 }
485
486 pub fn delete_threads(&self) -> Task<Result<()>> {
487 let connection = self.connection.clone();
488
489 self.executor.spawn(async move {
490 let connection = connection.lock();
491
492 let mut delete = connection.exec_bound::<()>(indoc! {"
493 DELETE FROM threads
494 "})?;
495
496 delete(())?;
497
498 Ok(())
499 })
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Midnight UTC on the given calendar day.
    fn midnight_utc(year: i32, month: u32, day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(year, month, day, 0, 0, 0).unwrap()
    }

    /// Builds a session id from a string literal.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Builds an otherwise empty thread with the given title and timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            completion_mode: None,
            profile: None,
            imported: false,
        }
    }

    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: midnight_utc(2024, 1, 1),
            model: None,
            completion_mode: None,
            version: SharedThread::VERSION.to_string(),
        };

        let encoded = original.to_bytes().expect("Failed to serialize");
        let decoded = SharedThread::from_bytes(&encoded).expect("Failed to deserialize");

        assert_eq!(decoded.title, original.title);
        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.updated_at, original.updated_at);
    }

    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let parsed: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            !parsed.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");
        let fixtures = [
            (
                older_id.clone(),
                make_thread("Thread A", midnight_utc(2024, 1, 1)),
            ),
            (
                newer_id.clone(),
                make_thread("Thread B", midnight_utc(2024, 1, 2)),
            ),
        ];
        for (id, thread) in fixtures {
            database.save_thread(id, thread).await.unwrap();
        }

        // Comparing the whole id list checks both the count and the ordering.
        let listed_ids: Vec<_> = database
            .list_threads()
            .await
            .unwrap()
            .into_iter()
            .map(|entry| entry.id)
            .collect();
        assert_eq!(listed_ids, vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let revisions = [
            make_thread("Thread A", midnight_utc(2024, 1, 1)),
            make_thread("Thread B", midnight_utc(2024, 1, 2)),
        ];
        // Saving twice under the same id must leave only the second revision.
        for revision in revisions {
            database
                .save_thread(thread_id.clone(), revision)
                .await
                .unwrap();
        }

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);

        let entry = &entries[0];
        assert_eq!(entry.id, thread_id);
        assert_eq!(entry.title.as_ref(), "Thread B");
        assert_eq!(entry.updated_at, midnight_utc(2024, 1, 2));
    }
}