1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use agent::thread_store;
3use agent_client_protocol as acp;
4use agent_settings::{AgentProfileId, CompletionMode};
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use sqlez::{
14 bindable::{Bind, Column},
15 connection::Connection,
16 statement::Statement,
17};
18use std::sync::Arc;
19use ui::{App, SharedString};
20
21pub type DbMessage = crate::Message;
22pub type DbSummary = agent::thread::DetailedSummaryState;
23pub type DbLanguageModel = thread_store::SerializedLanguageModel;
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct DbThreadMetadata {
27 pub id: acp::SessionId,
28 #[serde(alias = "summary")]
29 pub title: SharedString,
30 pub updated_at: DateTime<Utc>,
31}
32
33#[derive(Debug, Serialize, Deserialize)]
34pub struct DbThread {
35 pub title: SharedString,
36 pub messages: Vec<DbMessage>,
37 pub updated_at: DateTime<Utc>,
38 #[serde(default)]
39 pub summary: DbSummary,
40 #[serde(default)]
41 pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
42 #[serde(default)]
43 pub cumulative_token_usage: language_model::TokenUsage,
44 #[serde(default)]
45 pub request_token_usage: Vec<language_model::TokenUsage>,
46 #[serde(default)]
47 pub model: Option<DbLanguageModel>,
48 #[serde(default)]
49 pub completion_mode: Option<CompletionMode>,
50 #[serde(default)]
51 pub profile: Option<AgentProfileId>,
52}
53
54impl DbThread {
55 pub const VERSION: &'static str = "0.3.0";
56
57 pub fn from_json(json: &[u8]) -> Result<Self> {
58 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
59 match saved_thread_json.get("version") {
60 Some(serde_json::Value::String(version)) => match version.as_str() {
61 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
62 _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
63 },
64 _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
65 }
66 }
67
68 fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
69 let mut messages = Vec::new();
70 for msg in thread.messages {
71 let message = match msg.role {
72 language_model::Role::User => {
73 let mut content = Vec::new();
74
75 // Convert segments to content
76 for segment in msg.segments {
77 match segment {
78 thread_store::SerializedMessageSegment::Text { text } => {
79 content.push(UserMessageContent::Text(text));
80 }
81 thread_store::SerializedMessageSegment::Thinking { text, .. } => {
82 // User messages don't have thinking segments, but handle gracefully
83 content.push(UserMessageContent::Text(text));
84 }
85 thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
86 // User messages don't have redacted thinking, skip.
87 }
88 }
89 }
90
91 // If no content was added, add context as text if available
92 if content.is_empty() && !msg.context.is_empty() {
93 content.push(UserMessageContent::Text(msg.context));
94 }
95
96 crate::Message::User(UserMessage {
97 // MessageId from old format can't be meaningfully converted, so generate a new one
98 id: acp_thread::UserMessageId::new(),
99 content,
100 })
101 }
102 language_model::Role::Assistant => {
103 let mut content = Vec::new();
104
105 // Convert segments to content
106 for segment in msg.segments {
107 match segment {
108 thread_store::SerializedMessageSegment::Text { text } => {
109 content.push(AgentMessageContent::Text(text));
110 }
111 thread_store::SerializedMessageSegment::Thinking {
112 text,
113 signature,
114 } => {
115 content.push(AgentMessageContent::Thinking { text, signature });
116 }
117 thread_store::SerializedMessageSegment::RedactedThinking { data } => {
118 content.push(AgentMessageContent::RedactedThinking(data));
119 }
120 }
121 }
122
123 // Convert tool uses
124 let mut tool_names_by_id = HashMap::default();
125 for tool_use in msg.tool_uses {
126 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
127 content.push(AgentMessageContent::ToolUse(
128 language_model::LanguageModelToolUse {
129 id: tool_use.id,
130 name: tool_use.name.into(),
131 raw_input: serde_json::to_string(&tool_use.input)
132 .unwrap_or_default(),
133 input: tool_use.input,
134 is_input_complete: true,
135 },
136 ));
137 }
138
139 // Convert tool results
140 let mut tool_results = IndexMap::default();
141 for tool_result in msg.tool_results {
142 let name = tool_names_by_id
143 .remove(&tool_result.tool_use_id)
144 .unwrap_or_else(|| SharedString::from("unknown"));
145 tool_results.insert(
146 tool_result.tool_use_id.clone(),
147 language_model::LanguageModelToolResult {
148 tool_use_id: tool_result.tool_use_id,
149 tool_name: name.into(),
150 is_error: tool_result.is_error,
151 content: tool_result.content,
152 output: tool_result.output,
153 },
154 );
155 }
156
157 crate::Message::Agent(AgentMessage {
158 content,
159 tool_results,
160 })
161 }
162 language_model::Role::System => {
163 // Skip system messages as they're not supported in the new format
164 continue;
165 }
166 };
167
168 messages.push(message);
169 }
170
171 Ok(Self {
172 title: thread.summary,
173 messages,
174 updated_at: thread.updated_at,
175 summary: thread.detailed_summary_state,
176 initial_project_snapshot: thread.initial_project_snapshot,
177 cumulative_token_usage: thread.cumulative_token_usage,
178 request_token_usage: thread.request_token_usage,
179 model: thread.model,
180 completion_mode: thread.completion_mode,
181 profile: thread.profile,
182 })
183 }
184}
185
/// `true` when the `ZED_STATELESS` environment variable is set to a non-empty value.
///
/// The environment is read once, on first access; an unset, empty, or non-Unicode value
/// yields `false`.
pub static ZED_STATELESS: std::sync::LazyLock<bool> = std::sync::LazyLock::new(|| {
    matches!(std::env::var("ZED_STATELESS"), Ok(value) if !value.is_empty())
});
188
189#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
190pub enum DataType {
191 #[serde(rename = "json")]
192 Json,
193 #[serde(rename = "zstd")]
194 Zstd,
195}
196
197impl Bind for DataType {
198 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
199 let value = match self {
200 DataType::Json => "json",
201 DataType::Zstd => "zstd",
202 };
203 value.bind(statement, start_index)
204 }
205}
206
207impl Column for DataType {
208 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
209 let (value, next_index) = String::column(statement, start_index)?;
210 let data_type = match value.as_str() {
211 "json" => DataType::Json,
212 "zstd" => DataType::Zstd,
213 _ => anyhow::bail!("Unknown data type: {}", value),
214 };
215 Ok((data_type, next_index))
216 }
217}
218
219pub(crate) struct ThreadsDatabase {
220 executor: BackgroundExecutor,
221 connection: Arc<Mutex<Connection>>,
222}
223
224struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
225
226impl Global for GlobalThreadsDatabase {}
227
228impl ThreadsDatabase {
229 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
230 if cx.has_global::<GlobalThreadsDatabase>() {
231 return cx.global::<GlobalThreadsDatabase>().0.clone();
232 }
233 let executor = cx.background_executor().clone();
234 let task = executor
235 .spawn({
236 let executor = executor.clone();
237 async move {
238 match ThreadsDatabase::new(executor) {
239 Ok(db) => Ok(Arc::new(db)),
240 Err(err) => Err(Arc::new(err)),
241 }
242 }
243 })
244 .shared();
245
246 cx.set_global(GlobalThreadsDatabase(task.clone()));
247 task
248 }
249
250 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
251 let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
252 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
253 } else {
254 let threads_dir = paths::data_dir().join("threads");
255 std::fs::create_dir_all(&threads_dir)?;
256 let sqlite_path = threads_dir.join("threads.db");
257 Connection::open_file(&sqlite_path.to_string_lossy())
258 };
259
260 connection.exec(indoc! {"
261 CREATE TABLE IF NOT EXISTS threads (
262 id TEXT PRIMARY KEY,
263 summary TEXT NOT NULL,
264 updated_at TEXT NOT NULL,
265 data_type TEXT NOT NULL,
266 data BLOB NOT NULL
267 )
268 "})?()
269 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
270
271 let db = Self {
272 executor: executor.clone(),
273 connection: Arc::new(Mutex::new(connection)),
274 };
275
276 Ok(db)
277 }
278
279 fn save_thread_sync(
280 connection: &Arc<Mutex<Connection>>,
281 id: acp::SessionId,
282 thread: DbThread,
283 ) -> Result<()> {
284 const COMPRESSION_LEVEL: i32 = 3;
285
286 #[derive(Serialize)]
287 struct SerializedThread {
288 #[serde(flatten)]
289 thread: DbThread,
290 version: &'static str,
291 }
292
293 let title = thread.title.to_string();
294 let updated_at = thread.updated_at.to_rfc3339();
295 let json_data = serde_json::to_string(&SerializedThread {
296 thread,
297 version: DbThread::VERSION,
298 })?;
299
300 let connection = connection.lock();
301
302 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
303 let data_type = DataType::Zstd;
304 let data = compressed;
305
306 let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
307 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
308 "})?;
309
310 insert((id.0.clone(), title, updated_at, data_type, data))?;
311
312 Ok(())
313 }
314
315 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
316 let connection = self.connection.clone();
317
318 self.executor.spawn(async move {
319 let connection = connection.lock();
320
321 let mut select =
322 connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
323 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
324 "})?;
325
326 let rows = select(())?;
327 let mut threads = Vec::new();
328
329 for (id, summary, updated_at) in rows {
330 threads.push(DbThreadMetadata {
331 id: acp::SessionId(id),
332 title: summary.into(),
333 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
334 });
335 }
336
337 Ok(threads)
338 })
339 }
340
341 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
342 let connection = self.connection.clone();
343
344 self.executor.spawn(async move {
345 let connection = connection.lock();
346 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
347 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
348 "})?;
349
350 let rows = select(id.0)?;
351 if let Some((data_type, data)) = rows.into_iter().next() {
352 let json_data = match data_type {
353 DataType::Zstd => {
354 let decompressed = zstd::decode_all(&data[..])?;
355 String::from_utf8(decompressed)?
356 }
357 DataType::Json => String::from_utf8(data)?,
358 };
359 let thread = DbThread::from_json(json_data.as_bytes())?;
360 Ok(Some(thread))
361 } else {
362 Ok(None)
363 }
364 })
365 }
366
367 pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
368 let connection = self.connection.clone();
369
370 self.executor
371 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
372 }
373
374 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
375 let connection = self.connection.clone();
376
377 self.executor.spawn(async move {
378 let connection = connection.lock();
379
380 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
381 DELETE FROM threads WHERE id = ?
382 "})?;
383
384 delete(id.0)?;
385
386 Ok(())
387 })
388 }
389}
390
#[cfg(test)]
mod tests {

    use super::*;
    use agent::{MessageSegment, context::LoadedContext};
    use client::Client;
    use fs::FakeFs;
    use gpui::{AppContext, TestAppContext};
    use http_client::FakeHttpClient;
    use language_model::Role;
    use project::Project;
    use settings::SettingsStore;

    /// Installs the globals (settings, language, client, agent, language models) the test needs.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);

            let http_client = FakeHttpClient::with_404_response();
            let clock = Arc::new(clock::FakeSystemClock::new());
            let client = Client::new(clock, http_client, cx);
            agent::init(cx);
            agent_settings::init(cx);
            language_model::init(client.clone(), cx);
        });
    }

    /// A thread saved by the old agent must be listed and loadable through `ThreadsDatabase`.
    #[gpui::test]
    async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;

        // Save a thread using the old agent.
        let old_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
        let old_thread = old_store.update(cx, |store, cx| store.create_thread(cx));
        old_thread.update(cx, |thread, cx| {
            for (role, text) in [
                (Role::User, "Hey!"),
                (Role::Assistant, "How're you doing?"),
            ] {
                thread.insert_message(
                    role,
                    vec![MessageSegment::Text(text.into())],
                    LoadedContext::default(),
                    vec![],
                    false,
                    cx,
                );
            }
        });
        old_store
            .update(cx, |store, cx| store.save_thread(&old_thread, cx))
            .await
            .unwrap();

        // Open that same thread using the new agent.
        let db = cx.update(ThreadsDatabase::connect).await.unwrap();
        let listed = db.list_threads().await.unwrap();
        assert_eq!(listed.len(), 1);

        let loaded = db
            .load_thread(listed[0].id.clone())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(loaded.messages[0].to_markdown(), "## User\n\nHey!\n");
        assert_eq!(
            loaded.messages[1].to_markdown(),
            "## Assistant\n\nHow're you doing?\n"
        );
    }
}