1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::{AcpThreadMetadata, AgentServerName};
3use agent::thread_store;
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, CompletionMode};
6use anyhow::{Result, anyhow};
7use chrono::{DateTime, Utc};
8use collections::{HashMap, IndexMap};
9use futures::{FutureExt, future::Shared};
10use gpui::{BackgroundExecutor, Global, Task};
11use indoc::indoc;
12use parking_lot::Mutex;
13use serde::{Deserialize, Serialize};
14use sqlez::{
15 bindable::{Bind, Column},
16 connection::Connection,
17 statement::Statement,
18};
19use std::sync::Arc;
20use ui::{App, SharedString};
21
/// Message type stored in [`DbThread::messages`]; an alias for `crate::Message`.
pub type DbMessage = crate::Message;
/// Detailed-summary state stored with a thread; the type is reused from the legacy `agent` crate.
pub type DbSummary = agent::thread::DetailedSummaryState;
/// Language-model selection stored with a thread; reused from the legacy `agent` thread store.
pub type DbLanguageModel = thread_store::SerializedLanguageModel;
25
/// Lightweight per-thread metadata, as listed from the `threads` table without decoding the
/// full thread payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session id; stored as the `threads.id` primary key.
    pub id: acp::SessionId,
    /// Human-readable title. When deserializing, the key `summary` is also accepted, matching
    /// the name of the `threads` column that stores it.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Time of the thread's last update.
    pub updated_at: DateTime<Utc>,
}
33
34impl DbThreadMetadata {
35 pub fn to_acp(self, agent: AgentServerName) -> AcpThreadMetadata {
36 AcpThreadMetadata {
37 agent,
38 id: self.id,
39 title: self.title,
40 updated_at: self.updated_at,
41 }
42 }
43}
44
/// A complete thread as stored in the `threads.data` column (JSON, zstd-compressed on save).
///
/// Only `title`, `messages` and `updated_at` are required when deserializing; every other
/// field is `#[serde(default)]` and falls back to its default when absent.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; also written to the `threads.summary` column for cheap listing.
    pub title: SharedString,
    /// The thread's messages.
    pub messages: Vec<DbMessage>,
    /// Last-update timestamp; also written (as RFC 3339 text) to `threads.updated_at`.
    pub updated_at: DateTime<Utc>,
    /// Detailed-summary state (type reused from the legacy `agent` crate).
    #[serde(default)]
    pub summary: DbSummary,
    /// Initial project snapshot, if one was recorded.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
    /// Cumulative token usage for the thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Per-request token usage entries.
    #[serde(default)]
    pub request_token_usage: Vec<language_model::TokenUsage>,
    /// Language model selected for the thread, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Completion mode, if one was set.
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    /// Agent profile, if one was set.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
65
66impl DbThread {
67 pub const VERSION: &'static str = "0.3.0";
68
69 pub fn from_json(json: &[u8]) -> Result<Self> {
70 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
71 match saved_thread_json.get("version") {
72 Some(serde_json::Value::String(version)) => match version.as_str() {
73 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
74 _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
75 },
76 _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
77 }
78 }
79
80 fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
81 let mut messages = Vec::new();
82 for msg in thread.messages {
83 let message = match msg.role {
84 language_model::Role::User => {
85 let mut content = Vec::new();
86
87 // Convert segments to content
88 for segment in msg.segments {
89 match segment {
90 thread_store::SerializedMessageSegment::Text { text } => {
91 content.push(UserMessageContent::Text(text));
92 }
93 thread_store::SerializedMessageSegment::Thinking { text, .. } => {
94 // User messages don't have thinking segments, but handle gracefully
95 content.push(UserMessageContent::Text(text));
96 }
97 thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
98 // User messages don't have redacted thinking, skip.
99 }
100 }
101 }
102
103 // If no content was added, add context as text if available
104 if content.is_empty() && !msg.context.is_empty() {
105 content.push(UserMessageContent::Text(msg.context));
106 }
107
108 crate::Message::User(UserMessage {
109 // MessageId from old format can't be meaningfully converted, so generate a new one
110 id: acp_thread::UserMessageId::new(),
111 content,
112 })
113 }
114 language_model::Role::Assistant => {
115 let mut content = Vec::new();
116
117 // Convert segments to content
118 for segment in msg.segments {
119 match segment {
120 thread_store::SerializedMessageSegment::Text { text } => {
121 content.push(AgentMessageContent::Text(text));
122 }
123 thread_store::SerializedMessageSegment::Thinking {
124 text,
125 signature,
126 } => {
127 content.push(AgentMessageContent::Thinking { text, signature });
128 }
129 thread_store::SerializedMessageSegment::RedactedThinking { data } => {
130 content.push(AgentMessageContent::RedactedThinking(data));
131 }
132 }
133 }
134
135 // Convert tool uses
136 let mut tool_names_by_id = HashMap::default();
137 for tool_use in msg.tool_uses {
138 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
139 content.push(AgentMessageContent::ToolUse(
140 language_model::LanguageModelToolUse {
141 id: tool_use.id,
142 name: tool_use.name.into(),
143 raw_input: serde_json::to_string(&tool_use.input)
144 .unwrap_or_default(),
145 input: tool_use.input,
146 is_input_complete: true,
147 },
148 ));
149 }
150
151 // Convert tool results
152 let mut tool_results = IndexMap::default();
153 for tool_result in msg.tool_results {
154 let name = tool_names_by_id
155 .remove(&tool_result.tool_use_id)
156 .unwrap_or_else(|| SharedString::from("unknown"));
157 tool_results.insert(
158 tool_result.tool_use_id.clone(),
159 language_model::LanguageModelToolResult {
160 tool_use_id: tool_result.tool_use_id,
161 tool_name: name.into(),
162 is_error: tool_result.is_error,
163 content: tool_result.content,
164 output: tool_result.output,
165 },
166 );
167 }
168
169 crate::Message::Agent(AgentMessage {
170 content,
171 tool_results,
172 })
173 }
174 language_model::Role::System => {
175 // Skip system messages as they're not supported in the new format
176 continue;
177 }
178 };
179
180 messages.push(message);
181 }
182
183 Ok(Self {
184 title: thread.summary,
185 messages,
186 updated_at: thread.updated_at,
187 summary: thread.detailed_summary_state,
188 initial_project_snapshot: thread.initial_project_snapshot,
189 cumulative_token_usage: thread.cumulative_token_usage,
190 request_token_usage: thread.request_token_usage,
191 model: thread.model,
192 completion_mode: thread.completion_mode,
193 profile: thread.profile,
194 })
195 }
196}
197
/// `true` when the `ZED_STATELESS` environment variable is set to a non-empty, valid-Unicode
/// value. The variable is read once, on first access.
///
/// When set, `ThreadsDatabase::new` opens an in-memory database instead of the on-disk file.
pub static ZED_STATELESS: std::sync::LazyLock<bool> = std::sync::LazyLock::new(|| {
    std::env::var("ZED_STATELESS").is_ok_and(|value| !value.is_empty())
});
200
201#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
202pub enum DataType {
203 #[serde(rename = "json")]
204 Json,
205 #[serde(rename = "zstd")]
206 Zstd,
207}
208
209impl Bind for DataType {
210 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
211 let value = match self {
212 DataType::Json => "json",
213 DataType::Zstd => "zstd",
214 };
215 value.bind(statement, start_index)
216 }
217}
218
219impl Column for DataType {
220 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
221 let (value, next_index) = String::column(statement, start_index)?;
222 let data_type = match value.as_str() {
223 "json" => DataType::Json,
224 "zstd" => DataType::Zstd,
225 _ => anyhow::bail!("Unknown data type: {}", value),
226 };
227 Ok((data_type, next_index))
228 }
229}
230
/// Handle to the SQLite-backed store of agent threads (the `threads` table).
pub(crate) struct ThreadsDatabase {
    /// Executor on which the async query methods spawn their work.
    executor: BackgroundExecutor,
    /// Shared connection; each query locks it while it runs.
    connection: Arc<Mutex<Connection>>,
}
235
/// App-global slot holding the shared task that opens the `ThreadsDatabase`, so every caller of
/// `ThreadsDatabase::connect` awaits the same connection attempt.
///
/// Errors are wrapped in `Arc` because a `Shared` future's output must be `Clone`.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
239
240impl ThreadsDatabase {
241 fn connection(&self) -> Arc<Mutex<Connection>> {
242 self.connection.clone()
243 }
244
245 const COMPRESSION_LEVEL: i32 = 3;
246}
247
248impl ThreadsDatabase {
249 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
250 if cx.has_global::<GlobalThreadsDatabase>() {
251 return cx.global::<GlobalThreadsDatabase>().0.clone();
252 }
253 let executor = cx.background_executor().clone();
254 let task = executor
255 .spawn({
256 let executor = executor.clone();
257 async move {
258 match ThreadsDatabase::new(executor) {
259 Ok(db) => Ok(Arc::new(db)),
260 Err(err) => Err(Arc::new(err)),
261 }
262 }
263 })
264 .shared();
265
266 cx.set_global(GlobalThreadsDatabase(task.clone()));
267 task
268 }
269
270 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
271 let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
272 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
273 } else {
274 let threads_dir = paths::data_dir().join("threads");
275 std::fs::create_dir_all(&threads_dir)?;
276 let sqlite_path = threads_dir.join("threads.db");
277 Connection::open_file(&sqlite_path.to_string_lossy())
278 };
279
280 connection.exec(indoc! {"
281 CREATE TABLE IF NOT EXISTS threads (
282 id TEXT PRIMARY KEY,
283 summary TEXT NOT NULL,
284 updated_at TEXT NOT NULL,
285 data_type TEXT NOT NULL,
286 data BLOB NOT NULL
287 )
288 "})?()
289 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
290
291 let db = Self {
292 executor: executor.clone(),
293 connection: Arc::new(Mutex::new(connection)),
294 };
295
296 Ok(db)
297 }
298
299 fn save_thread_sync(
300 connection: &Arc<Mutex<Connection>>,
301 id: acp::SessionId,
302 thread: DbThread,
303 ) -> Result<DbThreadMetadata> {
304 let json_data = serde_json::to_string(&thread)?;
305 let title = thread.title.to_string();
306 let updated_at = thread.updated_at.to_rfc3339();
307
308 let connection = connection.lock();
309
310 let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
311 let data_type = DataType::Zstd;
312 let data = compressed;
313
314 let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
315 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
316 "})?;
317
318 insert((id.0.clone(), title, updated_at, data_type, data))?;
319
320 Ok(DbThreadMetadata {
321 id,
322 title: thread.title,
323 updated_at: thread.updated_at,
324 })
325 }
326
327 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
328 let connection = self.connection.clone();
329
330 self.executor.spawn(async move {
331 let connection = connection.lock();
332 let mut select =
333 connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
334 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
335 "})?;
336
337 let rows = select(())?;
338 let mut threads = Vec::new();
339
340 for (id, summary, updated_at) in rows {
341 threads.push(DbThreadMetadata {
342 id: acp::SessionId(id),
343 title: summary.into(),
344 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
345 });
346 }
347
348 Ok(threads)
349 })
350 }
351
352 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
353 let connection = self.connection.clone();
354
355 self.executor.spawn(async move {
356 let connection = connection.lock();
357 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
358 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
359 "})?;
360
361 let rows = select(id.0)?;
362 if let Some((data_type, data)) = rows.into_iter().next() {
363 let json_data = match data_type {
364 DataType::Zstd => {
365 let decompressed = zstd::decode_all(&data[..])?;
366 String::from_utf8(decompressed)?
367 }
368 DataType::Json => String::from_utf8(data)?,
369 };
370 dbg!(&json_data);
371
372 let thread = dbg!(DbThread::from_json(json_data.as_bytes()))?;
373 Ok(Some(thread))
374 } else {
375 Ok(None)
376 }
377 })
378 }
379
380 pub fn save_thread(
381 &self,
382 id: acp::SessionId,
383 thread: DbThread,
384 ) -> Task<Result<DbThreadMetadata>> {
385 let connection = self.connection.clone();
386
387 self.executor
388 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
389 }
390
391 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
392 let connection = self.connection.clone();
393
394 self.executor.spawn(async move {
395 let connection = connection.lock();
396
397 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
398 DELETE FROM threads WHERE id = ?
399 "})?;
400
401 delete(id.0)?;
402
403 Ok(())
404 })
405 }
406}
407
#[cfg(test)]
mod tests {

    use super::*;
    use agent::MessageSegment;
    use agent::context::LoadedContext;
    use client::Client;
    use fs::FakeFs;
    use gpui::AppContext;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use language_model::Role;
    use project::Project;
    use settings::SettingsStore;

    /// Sets up the test app's globals:
    /// - a test `SettingsStore` and project settings;
    /// - the language, agent and agent-settings subsystems;
    /// - the language-model subsystem, whose `Client` is built on a fake clock and
    ///   `FakeHttpClient::with_404_response`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);

            let http_client = FakeHttpClient::with_404_response();
            let clock = Arc::new(clock::FakeSystemClock::new());
            let client = Client::new(clock, http_client, cx);
            agent::init(cx);
            agent_settings::init(cx);
            language_model::init(client.clone(), cx);
        });
    }

    /// Saves a two-message thread through the legacy `agent::ThreadStore`, then checks that
    /// `ThreadsDatabase` lists it and converts both messages into the new message format.
    #[gpui::test]
    async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;

        // Save a thread using the old agent.
        // NOTE(review): this relies on the legacy store writing to the same database that
        // `ThreadsDatabase` reads (the in-memory `THREAD_FALLBACK_DB` in test builds) — confirm.
        {
            let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
            let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
            thread.update(cx, |thread, cx| {
                thread.insert_message(
                    Role::User,
                    vec![MessageSegment::Text("Hey!".into())],
                    LoadedContext::default(),
                    vec![],
                    false,
                    cx,
                );
                thread.insert_message(
                    Role::Assistant,
                    vec![MessageSegment::Text("How're you doing?".into())],
                    LoadedContext::default(),
                    vec![],
                    false,
                    cx,
                )
            });
            thread_store
                .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                .await
                .unwrap();
        }

        // Read it back through the new database, which upgrades the legacy payload on load.
        let db = cx.update(|cx| ThreadsDatabase::connect(cx)).await.unwrap();
        let threads = db.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        let thread = db
            .load_thread(threads[0].id.clone())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n");
        assert_eq!(
            thread.messages[1].to_markdown(),
            "## Assistant\n\nHow're you doing?\n"
        );
    }
}