1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent::{thread::DetailedSummaryState, thread_store};
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, CompletionMode};
6use anyhow::{Result, anyhow};
7use chrono::{DateTime, Utc};
8use collections::{HashMap, IndexMap};
9use futures::{FutureExt, future::Shared};
10use gpui::{BackgroundExecutor, Global, Task};
11use indoc::indoc;
12use parking_lot::Mutex;
13use serde::{Deserialize, Serialize};
14use sqlez::{
15 bindable::{Bind, Column},
16 connection::Connection,
17 statement::Statement,
18};
19use std::sync::Arc;
20use ui::{App, SharedString};
21
22pub type DbMessage = crate::Message;
23pub type DbSummary = DetailedSummaryState;
24pub type DbLanguageModel = thread_store::SerializedLanguageModel;
25
/// Lightweight description of a saved thread, built from the `threads` table's
/// `id`, `summary` and `updated_at` columns without decoding the thread payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session id of the thread; stored as the `id` primary key of the `threads` table.
    pub id: acp::SessionId,
    /// Display title. When deserializing, the key `summary` is accepted as an alias
    /// (presumably for data written before the field was renamed — confirm).
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Time of the thread's last update.
    pub updated_at: DateTime<Utc>,
}
33
/// Full persisted form of an agent thread.
///
/// Fields marked `#[serde(default)]` fall back to their `Default` value when absent from
/// the stored JSON, so payloads that lack them still deserialize.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Display title; also written to the `summary` column when saved.
    pub title: SharedString,
    /// Conversation messages in order.
    pub messages: Vec<DbMessage>,
    /// Time of the thread's last update; also written to the `updated_at` column when saved.
    pub updated_at: DateTime<Utc>,
    /// Generated detailed summary, if one exists.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Snapshot of the project captured when the thread was created, if any.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
    /// Token usage accumulated over the whole thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage per request, keyed by the id of the user message that started it.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Language model selected for the thread, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Completion mode selected for the thread, if any.
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    /// Agent profile selected for the thread, if any.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
54
55impl DbThread {
56 pub const VERSION: &'static str = "0.3.0";
57
58 pub fn from_json(json: &[u8]) -> Result<Self> {
59 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
60 match saved_thread_json.get("version") {
61 Some(serde_json::Value::String(version)) => match version.as_str() {
62 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
63 _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
64 },
65 _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
66 }
67 }
68
69 fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
70 let mut messages = Vec::new();
71 let mut request_token_usage = HashMap::default();
72
73 let mut last_user_message_id = None;
74 for (ix, msg) in thread.messages.into_iter().enumerate() {
75 let message = match msg.role {
76 language_model::Role::User => {
77 let mut content = Vec::new();
78
79 // Convert segments to content
80 for segment in msg.segments {
81 match segment {
82 thread_store::SerializedMessageSegment::Text { text } => {
83 content.push(UserMessageContent::Text(text));
84 }
85 thread_store::SerializedMessageSegment::Thinking { text, .. } => {
86 // User messages don't have thinking segments, but handle gracefully
87 content.push(UserMessageContent::Text(text));
88 }
89 thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
90 // User messages don't have redacted thinking, skip.
91 }
92 }
93 }
94
95 // If no content was added, add context as text if available
96 if content.is_empty() && !msg.context.is_empty() {
97 content.push(UserMessageContent::Text(msg.context));
98 }
99
100 let id = UserMessageId::new();
101 last_user_message_id = Some(id.clone());
102
103 crate::Message::User(UserMessage {
104 // MessageId from old format can't be meaningfully converted, so generate a new one
105 id,
106 content,
107 })
108 }
109 language_model::Role::Assistant => {
110 let mut content = Vec::new();
111
112 // Convert segments to content
113 for segment in msg.segments {
114 match segment {
115 thread_store::SerializedMessageSegment::Text { text } => {
116 content.push(AgentMessageContent::Text(text));
117 }
118 thread_store::SerializedMessageSegment::Thinking {
119 text,
120 signature,
121 } => {
122 content.push(AgentMessageContent::Thinking { text, signature });
123 }
124 thread_store::SerializedMessageSegment::RedactedThinking { data } => {
125 content.push(AgentMessageContent::RedactedThinking(data));
126 }
127 }
128 }
129
130 // Convert tool uses
131 let mut tool_names_by_id = HashMap::default();
132 for tool_use in msg.tool_uses {
133 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
134 content.push(AgentMessageContent::ToolUse(
135 language_model::LanguageModelToolUse {
136 id: tool_use.id,
137 name: tool_use.name.into(),
138 raw_input: serde_json::to_string(&tool_use.input)
139 .unwrap_or_default(),
140 input: tool_use.input,
141 is_input_complete: true,
142 },
143 ));
144 }
145
146 // Convert tool results
147 let mut tool_results = IndexMap::default();
148 for tool_result in msg.tool_results {
149 let name = tool_names_by_id
150 .remove(&tool_result.tool_use_id)
151 .unwrap_or_else(|| SharedString::from("unknown"));
152 tool_results.insert(
153 tool_result.tool_use_id.clone(),
154 language_model::LanguageModelToolResult {
155 tool_use_id: tool_result.tool_use_id,
156 tool_name: name.into(),
157 is_error: tool_result.is_error,
158 content: tool_result.content,
159 output: tool_result.output,
160 },
161 );
162 }
163
164 if let Some(last_user_message_id) = &last_user_message_id
165 && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
166 {
167 request_token_usage.insert(last_user_message_id.clone(), token_usage);
168 }
169
170 crate::Message::Agent(AgentMessage {
171 content,
172 tool_results,
173 })
174 }
175 language_model::Role::System => {
176 // Skip system messages as they're not supported in the new format
177 continue;
178 }
179 };
180
181 messages.push(message);
182 }
183
184 Ok(Self {
185 title: thread.summary,
186 messages,
187 updated_at: thread.updated_at,
188 detailed_summary: match thread.detailed_summary_state {
189 DetailedSummaryState::NotGenerated | DetailedSummaryState::Generating { .. } => {
190 None
191 }
192 DetailedSummaryState::Generated { text, .. } => Some(text),
193 },
194 initial_project_snapshot: thread.initial_project_snapshot,
195 cumulative_token_usage: thread.cumulative_token_usage,
196 request_token_usage,
197 model: thread.model,
198 completion_mode: thread.completion_mode,
199 profile: thread.profile,
200 })
201 }
202}
203
/// `true` when the `ZED_STATELESS` environment variable is set to a non-empty value.
///
/// The variable is read once, on first access. When it is set, `ThreadsDatabase::new`
/// opens an in-memory database instead of the on-disk one.
pub static ZED_STATELESS: std::sync::LazyLock<bool> = std::sync::LazyLock::new(|| {
    matches!(std::env::var("ZED_STATELESS"), Ok(value) if !value.is_empty())
});
206
207#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
208pub enum DataType {
209 #[serde(rename = "json")]
210 Json,
211 #[serde(rename = "zstd")]
212 Zstd,
213}
214
215impl Bind for DataType {
216 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
217 let value = match self {
218 DataType::Json => "json",
219 DataType::Zstd => "zstd",
220 };
221 value.bind(statement, start_index)
222 }
223}
224
225impl Column for DataType {
226 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
227 let (value, next_index) = String::column(statement, start_index)?;
228 let data_type = match value.as_str() {
229 "json" => DataType::Json,
230 "zstd" => DataType::Zstd,
231 _ => anyhow::bail!("Unknown data type: {}", value),
232 };
233 Ok((data_type, next_index))
234 }
235}
236
/// SQLite-backed store for agent threads, kept in a single `threads` table.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database operation is spawned.
    executor: BackgroundExecutor,
    /// The single shared connection; each operation locks this mutex to use it.
    connection: Arc<Mutex<Connection>>,
}
241
/// gpui global holding the shared task that opens the [`ThreadsDatabase`].
///
/// `ThreadsDatabase::connect` stores it on first use and hands out clones afterwards, so all
/// callers await the same connection attempt (including its error, if it failed).
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
245
246impl ThreadsDatabase {
247 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
248 if cx.has_global::<GlobalThreadsDatabase>() {
249 return cx.global::<GlobalThreadsDatabase>().0.clone();
250 }
251 let executor = cx.background_executor().clone();
252 let task = executor
253 .spawn({
254 let executor = executor.clone();
255 async move {
256 match ThreadsDatabase::new(executor) {
257 Ok(db) => Ok(Arc::new(db)),
258 Err(err) => Err(Arc::new(err)),
259 }
260 }
261 })
262 .shared();
263
264 cx.set_global(GlobalThreadsDatabase(task.clone()));
265 task
266 }
267
268 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
269 let connection = if *ZED_STATELESS {
270 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
271 } else if cfg!(any(feature = "test-support", test)) {
272 // rust stores the name of the test on the current thread.
273 // We use this to automatically create a database that will
274 // be shared within the test (for the test_retrieve_old_thread)
275 // but not with concurrent tests.
276 let thread = std::thread::current();
277 let test_name = thread.name();
278 Connection::open_memory(Some(&format!(
279 "THREAD_FALLBACK_{}",
280 test_name.unwrap_or_default()
281 )))
282 } else {
283 let threads_dir = paths::data_dir().join("threads");
284 std::fs::create_dir_all(&threads_dir)?;
285 let sqlite_path = threads_dir.join("threads.db");
286 Connection::open_file(&sqlite_path.to_string_lossy())
287 };
288
289 connection.exec(indoc! {"
290 CREATE TABLE IF NOT EXISTS threads (
291 id TEXT PRIMARY KEY,
292 summary TEXT NOT NULL,
293 updated_at TEXT NOT NULL,
294 data_type TEXT NOT NULL,
295 data BLOB NOT NULL
296 )
297 "})?()
298 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
299
300 let db = Self {
301 executor,
302 connection: Arc::new(Mutex::new(connection)),
303 };
304
305 Ok(db)
306 }
307
308 fn save_thread_sync(
309 connection: &Arc<Mutex<Connection>>,
310 id: acp::SessionId,
311 thread: DbThread,
312 ) -> Result<()> {
313 const COMPRESSION_LEVEL: i32 = 3;
314
315 #[derive(Serialize)]
316 struct SerializedThread {
317 #[serde(flatten)]
318 thread: DbThread,
319 version: &'static str,
320 }
321
322 let title = thread.title.to_string();
323 let updated_at = thread.updated_at.to_rfc3339();
324 let json_data = serde_json::to_string(&SerializedThread {
325 thread,
326 version: DbThread::VERSION,
327 })?;
328
329 let connection = connection.lock();
330
331 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
332 let data_type = DataType::Zstd;
333 let data = compressed;
334
335 let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
336 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
337 "})?;
338
339 insert((id.0, title, updated_at, data_type, data))?;
340
341 Ok(())
342 }
343
344 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
345 let connection = self.connection.clone();
346
347 self.executor.spawn(async move {
348 let connection = connection.lock();
349
350 let mut select =
351 connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
352 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
353 "})?;
354
355 let rows = select(())?;
356 let mut threads = Vec::new();
357
358 for (id, summary, updated_at) in rows {
359 threads.push(DbThreadMetadata {
360 id: acp::SessionId(id),
361 title: summary.into(),
362 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
363 });
364 }
365
366 Ok(threads)
367 })
368 }
369
370 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
371 let connection = self.connection.clone();
372
373 self.executor.spawn(async move {
374 let connection = connection.lock();
375 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
376 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
377 "})?;
378
379 let rows = select(id.0)?;
380 if let Some((data_type, data)) = rows.into_iter().next() {
381 let json_data = match data_type {
382 DataType::Zstd => {
383 let decompressed = zstd::decode_all(&data[..])?;
384 String::from_utf8(decompressed)?
385 }
386 DataType::Json => String::from_utf8(data)?,
387 };
388 let thread = DbThread::from_json(json_data.as_bytes())?;
389 Ok(Some(thread))
390 } else {
391 Ok(None)
392 }
393 })
394 }
395
396 pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
397 let connection = self.connection.clone();
398
399 self.executor
400 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
401 }
402
403 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
404 let connection = self.connection.clone();
405
406 self.executor.spawn(async move {
407 let connection = connection.lock();
408
409 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
410 DELETE FROM threads WHERE id = ?
411 "})?;
412
413 delete(id.0)?;
414
415 Ok(())
416 })
417 }
418}
419
#[cfg(test)]
mod tests {

    use super::*;
    use agent::MessageSegment;
    use agent::context::LoadedContext;
    use client::Client;
    use fs::FakeFs;
    use gpui::AppContext;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use language_model::Role;
    use project::Project;
    use settings::SettingsStore;

    /// Installs the globals (settings, language, client, agent, language models) that the
    /// old agent's thread store needs. The init calls are kept in their original order.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);

            // Offline client: every HTTP request gets a 404; the clock is fake.
            let http_client = FakeHttpClient::with_404_response();
            let clock = Arc::new(clock::FakeSystemClock::new());
            let client = Client::new(clock, http_client, cx);
            agent::init(cx);
            agent_settings::init(cx);
            language_model::init(client, cx);
        });
    }

    /// A thread saved by the old agent's `ThreadStore` can be listed and loaded through
    /// `ThreadsDatabase`. This exercises the `DbThread::from_json` upgrade path.
    #[gpui::test]
    async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;

        // Save a thread using the old agent.
        let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
        thread.update(cx, |thread, cx| {
            thread.insert_message(
                Role::User,
                vec![MessageSegment::Text("Hey!".into())],
                LoadedContext::default(),
                vec![],
                false,
                cx,
            );
            thread.insert_message(
                Role::Assistant,
                vec![MessageSegment::Text("How're you doing?".into())],
                LoadedContext::default(),
                vec![],
                false,
                cx,
            )
        });
        thread_store
            .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
            .await
            .unwrap();

        // Open that same thread using the new agent.
        // NOTE(review): this assumes the old store writes to the same in-memory database,
        // keyed by the test thread's name in `ThreadsDatabase::new` — confirm.
        let db = cx.update(ThreadsDatabase::connect).await.unwrap();
        let threads = db.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        let thread = db
            .load_thread(threads[0].id.clone())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n");
        assert_eq!(
            thread.messages[1].to_markdown(),
            "## Assistant\n\nHow're you doing?\n"
        );
    }
}