1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent::{thread::DetailedSummaryState, thread_store};
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, CompletionMode};
6use anyhow::{Result, anyhow};
7use chrono::{DateTime, Utc};
8use collections::{HashMap, IndexMap};
9use futures::{FutureExt, future::Shared};
10use gpui::{BackgroundExecutor, Global, Task};
11use indoc::indoc;
12use parking_lot::Mutex;
13use serde::{Deserialize, Serialize};
14use sqlez::{
15 bindable::{Bind, Column},
16 connection::Connection,
17 statement::Statement,
18};
19use std::sync::Arc;
20use ui::{App, SharedString};
21use zed_env_vars::ZED_STATELESS;
22
/// Message type persisted in the threads database (alias of [`crate::Message`]).
pub type DbMessage = crate::Message;
/// Detailed-summary state type reused from the legacy agent's thread module.
pub type DbSummary = DetailedSummaryState;
/// Serialized language-model selection reused from the legacy thread store.
pub type DbLanguageModel = thread_store::SerializedLanguageModel;
26
/// Lightweight listing entry for a stored thread, as produced by
/// [`ThreadsDatabase::list_threads`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session id the thread is stored under.
    pub id: acp::SessionId,
    /// Display title. When deserializing, the key `summary` is accepted as well.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Time of the thread's last update.
    pub updated_at: DateTime<Utc>,
}
34
/// A complete thread as persisted in the threads database.
///
/// Fields marked `#[serde(default)]` fall back to their `Default` value when the
/// key is missing from the stored JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Display title of the thread.
    pub title: SharedString,
    /// Conversation messages, in order.
    pub messages: Vec<DbMessage>,
    /// Time of the thread's last update.
    pub updated_at: DateTime<Utc>,
    /// Detailed summary text, if one was generated.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Optional initial project snapshot, shared via `Arc`.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
    /// Token usage accumulated for the thread as a whole.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Per-request token usage keyed by user message id.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Language model selected for the thread, if recorded.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Completion mode, if recorded.
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    /// Agent profile, if recorded.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
55
56impl DbThread {
57 pub const VERSION: &'static str = "0.3.0";
58
59 pub fn from_json(json: &[u8]) -> Result<Self> {
60 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
61 match saved_thread_json.get("version") {
62 Some(serde_json::Value::String(version)) => match version.as_str() {
63 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
64 _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
65 },
66 _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
67 }
68 }
69
70 fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
71 let mut messages = Vec::new();
72 let mut request_token_usage = HashMap::default();
73
74 let mut last_user_message_id = None;
75 for (ix, msg) in thread.messages.into_iter().enumerate() {
76 let message = match msg.role {
77 language_model::Role::User => {
78 let mut content = Vec::new();
79
80 // Convert segments to content
81 for segment in msg.segments {
82 match segment {
83 thread_store::SerializedMessageSegment::Text { text } => {
84 content.push(UserMessageContent::Text(text));
85 }
86 thread_store::SerializedMessageSegment::Thinking { text, .. } => {
87 // User messages don't have thinking segments, but handle gracefully
88 content.push(UserMessageContent::Text(text));
89 }
90 thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
91 // User messages don't have redacted thinking, skip.
92 }
93 }
94 }
95
96 // If no content was added, add context as text if available
97 if content.is_empty() && !msg.context.is_empty() {
98 content.push(UserMessageContent::Text(msg.context));
99 }
100
101 let id = UserMessageId::new();
102 last_user_message_id = Some(id.clone());
103
104 crate::Message::User(UserMessage {
105 // MessageId from old format can't be meaningfully converted, so generate a new one
106 id,
107 content,
108 })
109 }
110 language_model::Role::Assistant => {
111 let mut content = Vec::new();
112
113 // Convert segments to content
114 for segment in msg.segments {
115 match segment {
116 thread_store::SerializedMessageSegment::Text { text } => {
117 content.push(AgentMessageContent::Text(text));
118 }
119 thread_store::SerializedMessageSegment::Thinking {
120 text,
121 signature,
122 } => {
123 content.push(AgentMessageContent::Thinking { text, signature });
124 }
125 thread_store::SerializedMessageSegment::RedactedThinking { data } => {
126 content.push(AgentMessageContent::RedactedThinking(data));
127 }
128 }
129 }
130
131 // Convert tool uses
132 let mut tool_names_by_id = HashMap::default();
133 for tool_use in msg.tool_uses {
134 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
135 content.push(AgentMessageContent::ToolUse(
136 language_model::LanguageModelToolUse {
137 id: tool_use.id,
138 name: tool_use.name.into(),
139 raw_input: serde_json::to_string(&tool_use.input)
140 .unwrap_or_default(),
141 input: tool_use.input,
142 is_input_complete: true,
143 },
144 ));
145 }
146
147 // Convert tool results
148 let mut tool_results = IndexMap::default();
149 for tool_result in msg.tool_results {
150 let name = tool_names_by_id
151 .remove(&tool_result.tool_use_id)
152 .unwrap_or_else(|| SharedString::from("unknown"));
153 tool_results.insert(
154 tool_result.tool_use_id.clone(),
155 language_model::LanguageModelToolResult {
156 tool_use_id: tool_result.tool_use_id,
157 tool_name: name.into(),
158 is_error: tool_result.is_error,
159 content: tool_result.content,
160 output: tool_result.output,
161 },
162 );
163 }
164
165 if let Some(last_user_message_id) = &last_user_message_id
166 && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
167 {
168 request_token_usage.insert(last_user_message_id.clone(), token_usage);
169 }
170
171 crate::Message::Agent(AgentMessage {
172 content,
173 tool_results,
174 })
175 }
176 language_model::Role::System => {
177 // Skip system messages as they're not supported in the new format
178 continue;
179 }
180 };
181
182 messages.push(message);
183 }
184
185 Ok(Self {
186 title: thread.summary,
187 messages,
188 updated_at: thread.updated_at,
189 detailed_summary: match thread.detailed_summary_state {
190 DetailedSummaryState::NotGenerated | DetailedSummaryState::Generating { .. } => {
191 None
192 }
193 DetailedSummaryState::Generated { text, .. } => Some(text),
194 },
195 initial_project_snapshot: thread.initial_project_snapshot,
196 cumulative_token_usage: thread.cumulative_token_usage,
197 request_token_usage,
198 model: thread.model,
199 completion_mode: thread.completion_mode,
200 profile: thread.profile,
201 })
202 }
203}
204
205#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
206pub enum DataType {
207 #[serde(rename = "json")]
208 Json,
209 #[serde(rename = "zstd")]
210 Zstd,
211}
212
213impl Bind for DataType {
214 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
215 let value = match self {
216 DataType::Json => "json",
217 DataType::Zstd => "zstd",
218 };
219 value.bind(statement, start_index)
220 }
221}
222
223impl Column for DataType {
224 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
225 let (value, next_index) = String::column(statement, start_index)?;
226 let data_type = match value.as_str() {
227 "json" => DataType::Json,
228 "zstd" => DataType::Zstd,
229 _ => anyhow::bail!("Unknown data type: {}", value),
230 };
231 Ok((data_type, next_index))
232 }
233}
234
/// SQLite-backed store for agent threads.
///
/// Every query goes through one connection guarded by a mutex and is spawned on
/// the background executor.
pub(crate) struct ThreadsDatabase {
    /// Executor on which database work is spawned.
    executor: BackgroundExecutor,
    /// Shared connection; each query locks it while it runs.
    connection: Arc<Mutex<Connection>>,
}
239
/// App-global holder for the shared task that opens the [`ThreadsDatabase`].
///
/// `ThreadsDatabase::connect` stores the task here so that later callers reuse it
/// instead of opening the database again.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
243
244impl ThreadsDatabase {
245 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
246 if cx.has_global::<GlobalThreadsDatabase>() {
247 return cx.global::<GlobalThreadsDatabase>().0.clone();
248 }
249 let executor = cx.background_executor().clone();
250 let task = executor
251 .spawn({
252 let executor = executor.clone();
253 async move {
254 match ThreadsDatabase::new(executor) {
255 Ok(db) => Ok(Arc::new(db)),
256 Err(err) => Err(Arc::new(err)),
257 }
258 }
259 })
260 .shared();
261
262 cx.set_global(GlobalThreadsDatabase(task.clone()));
263 task
264 }
265
266 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
267 let connection = if *ZED_STATELESS {
268 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
269 } else if cfg!(any(feature = "test-support", test)) {
270 // rust stores the name of the test on the current thread.
271 // We use this to automatically create a database that will
272 // be shared within the test (for the test_retrieve_old_thread)
273 // but not with concurrent tests.
274 let thread = std::thread::current();
275 let test_name = thread.name();
276 Connection::open_memory(Some(&format!(
277 "THREAD_FALLBACK_{}",
278 test_name.unwrap_or_default()
279 )))
280 } else {
281 let threads_dir = paths::data_dir().join("threads");
282 std::fs::create_dir_all(&threads_dir)?;
283 let sqlite_path = threads_dir.join("threads.db");
284 Connection::open_file(&sqlite_path.to_string_lossy())
285 };
286
287 connection.exec(indoc! {"
288 CREATE TABLE IF NOT EXISTS threads (
289 id TEXT PRIMARY KEY,
290 summary TEXT NOT NULL,
291 updated_at TEXT NOT NULL,
292 data_type TEXT NOT NULL,
293 data BLOB NOT NULL
294 )
295 "})?()
296 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
297
298 let db = Self {
299 executor,
300 connection: Arc::new(Mutex::new(connection)),
301 };
302
303 Ok(db)
304 }
305
306 fn save_thread_sync(
307 connection: &Arc<Mutex<Connection>>,
308 id: acp::SessionId,
309 thread: DbThread,
310 ) -> Result<()> {
311 const COMPRESSION_LEVEL: i32 = 3;
312
313 #[derive(Serialize)]
314 struct SerializedThread {
315 #[serde(flatten)]
316 thread: DbThread,
317 version: &'static str,
318 }
319
320 let title = thread.title.to_string();
321 let updated_at = thread.updated_at.to_rfc3339();
322 let json_data = serde_json::to_string(&SerializedThread {
323 thread,
324 version: DbThread::VERSION,
325 })?;
326
327 let connection = connection.lock();
328
329 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
330 let data_type = DataType::Zstd;
331 let data = compressed;
332
333 let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
334 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
335 "})?;
336
337 insert((id.0, title, updated_at, data_type, data))?;
338
339 Ok(())
340 }
341
342 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
343 let connection = self.connection.clone();
344
345 self.executor.spawn(async move {
346 let connection = connection.lock();
347
348 let mut select =
349 connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
350 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
351 "})?;
352
353 let rows = select(())?;
354 let mut threads = Vec::new();
355
356 for (id, summary, updated_at) in rows {
357 threads.push(DbThreadMetadata {
358 id: acp::SessionId(id),
359 title: summary.into(),
360 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
361 });
362 }
363
364 Ok(threads)
365 })
366 }
367
368 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
369 let connection = self.connection.clone();
370
371 self.executor.spawn(async move {
372 let connection = connection.lock();
373 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
374 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
375 "})?;
376
377 let rows = select(id.0)?;
378 if let Some((data_type, data)) = rows.into_iter().next() {
379 let json_data = match data_type {
380 DataType::Zstd => {
381 let decompressed = zstd::decode_all(&data[..])?;
382 String::from_utf8(decompressed)?
383 }
384 DataType::Json => String::from_utf8(data)?,
385 };
386 let thread = DbThread::from_json(json_data.as_bytes())?;
387 Ok(Some(thread))
388 } else {
389 Ok(None)
390 }
391 })
392 }
393
394 pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
395 let connection = self.connection.clone();
396
397 self.executor
398 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
399 }
400
401 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
402 let connection = self.connection.clone();
403
404 self.executor.spawn(async move {
405 let connection = connection.lock();
406
407 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
408 DELETE FROM threads WHERE id = ?
409 "})?;
410
411 delete(id.0)?;
412
413 Ok(())
414 })
415 }
416}
417
#[cfg(test)]
mod tests {

    use super::*;
    use agent::MessageSegment;
    use agent::context::LoadedContext;
    use client::Client;
    use fs::FakeFs;
    use gpui::AppContext;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use language_model::Role;
    use project::Project;
    use settings::SettingsStore;

    /// Installs the globals these tests need, in order:
    /// 1. a test settings store;
    /// 2. project settings and language initialization;
    /// 3. the agent, agent-settings and language-model subsystems, using a client
    ///    whose fake HTTP client answers every request with 404.
    fn init_test(cx: &mut TestAppContext) {
        // `try_init` fails if a logger is already installed; that is fine here.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);

            let http_client = FakeHttpClient::with_404_response();
            let clock = Arc::new(clock::FakeSystemClock::new());
            let client = Client::new(clock, http_client, cx);
            agent::init(cx);
            agent_settings::init(cx);
            language_model::init(client, cx);
        });
    }

    /// A thread saved through the legacy agent's `ThreadStore` must be listed by
    /// `ThreadsDatabase` and load with its messages converted to the new format.
    #[gpui::test]
    async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;

        // Save a thread using the old agent.
        let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
        thread.update(cx, |thread, cx| {
            thread.insert_message(
                Role::User,
                vec![MessageSegment::Text("Hey!".into())],
                LoadedContext::default(),
                vec![],
                false,
                cx,
            );
            thread.insert_message(
                Role::Assistant,
                vec![MessageSegment::Text("How're you doing?".into())],
                LoadedContext::default(),
                vec![],
                false,
                cx,
            )
        });
        thread_store
            .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
            .await
            .unwrap();

        // Open that same thread using the new agent.
        let db = cx.update(ThreadsDatabase::connect).await.unwrap();
        let threads = db.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        let thread = db
            .load_thread(threads[0].id.clone())
            .await
            .unwrap()
            .unwrap();
        // The legacy user and assistant messages survive the upgrade in order.
        assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n");
        assert_eq!(
            thread.messages[1].to_markdown(),
            "## Assistant\n\nHow're you doing?\n"
        );
    }
}