1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent_client_protocol as acp;
4use agent_settings::{AgentProfileId, CompletionMode};
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use sqlez::{
14 bindable::{Bind, Column},
15 connection::Connection,
16 statement::Statement,
17};
18use std::sync::Arc;
19use ui::{App, SharedString};
20use zed_env_vars::ZED_STATELESS;
21
/// Message type stored inside a [`DbThread`]; an alias for `crate::Message`.
pub type DbMessage = crate::Message;
/// Alias for the legacy thread format's detailed-summary state.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Alias for the legacy thread format's serialized language model description.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
25
/// Listing information for a stored thread: its id, title and last-update time.
///
/// `ThreadsDatabase::list_threads` builds these from the `id`, `summary` and
/// `updated_at` columns of the `threads` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session id of the thread (the `id` column).
    pub id: acp::SessionId,
    /// Thread title. When deserializing, the key `summary` is accepted as an
    /// alternative name for this field.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Time of the last update, in UTC.
    pub updated_at: DateTime<Utc>,
}
33
/// Serializable form of a thread as it is written to and read from the
/// `threads` table.
///
/// Every field marked `#[serde(default)]` may be absent from the stored JSON,
/// in which case it deserializes to its `Default` value.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; also written to the `summary` column when the thread is saved.
    pub title: SharedString,
    /// Messages of the conversation, in order.
    pub messages: Vec<DbMessage>,
    /// Time of the last update; stored in the `updated_at` column as an RFC 3339 string.
    pub updated_at: DateTime<Utc>,
    /// Detailed summary text, if any. When upgrading a legacy thread this is
    /// only set if the legacy summary state was `Generated`.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Project snapshot carried over with the thread, if any.
    // NOTE(review): presumably captured when the thread was started — confirm with the producer.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    /// Token usage accumulated over the whole thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage recorded per request, keyed by the id of a user message.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Language model associated with the thread, if recorded.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Completion mode associated with the thread, if recorded.
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    /// Agent profile associated with the thread, if recorded.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
54
55impl DbThread {
56 pub const VERSION: &'static str = "0.3.0";
57
58 pub fn from_json(json: &[u8]) -> Result<Self> {
59 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
60 match saved_thread_json.get("version") {
61 Some(serde_json::Value::String(version)) => match version.as_str() {
62 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
63 _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
64 json,
65 )?),
66 },
67 _ => {
68 Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
69 }
70 }
71 }
72
73 fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
74 let mut messages = Vec::new();
75 let mut request_token_usage = HashMap::default();
76
77 let mut last_user_message_id = None;
78 for (ix, msg) in thread.messages.into_iter().enumerate() {
79 let message = match msg.role {
80 language_model::Role::User => {
81 let mut content = Vec::new();
82
83 // Convert segments to content
84 for segment in msg.segments {
85 match segment {
86 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
87 content.push(UserMessageContent::Text(text));
88 }
89 crate::legacy_thread::SerializedMessageSegment::Thinking {
90 text,
91 ..
92 } => {
93 // User messages don't have thinking segments, but handle gracefully
94 content.push(UserMessageContent::Text(text));
95 }
96 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
97 ..
98 } => {
99 // User messages don't have redacted thinking, skip.
100 }
101 }
102 }
103
104 // If no content was added, add context as text if available
105 if content.is_empty() && !msg.context.is_empty() {
106 content.push(UserMessageContent::Text(msg.context));
107 }
108
109 let id = UserMessageId::new();
110 last_user_message_id = Some(id.clone());
111
112 crate::Message::User(UserMessage {
113 // MessageId from old format can't be meaningfully converted, so generate a new one
114 id,
115 content,
116 })
117 }
118 language_model::Role::Assistant => {
119 let mut content = Vec::new();
120
121 // Convert segments to content
122 for segment in msg.segments {
123 match segment {
124 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
125 content.push(AgentMessageContent::Text(text));
126 }
127 crate::legacy_thread::SerializedMessageSegment::Thinking {
128 text,
129 signature,
130 } => {
131 content.push(AgentMessageContent::Thinking { text, signature });
132 }
133 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
134 data,
135 } => {
136 content.push(AgentMessageContent::RedactedThinking(data));
137 }
138 }
139 }
140
141 // Convert tool uses
142 let mut tool_names_by_id = HashMap::default();
143 for tool_use in msg.tool_uses {
144 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
145 content.push(AgentMessageContent::ToolUse(
146 language_model::LanguageModelToolUse {
147 id: tool_use.id,
148 name: tool_use.name.into(),
149 raw_input: serde_json::to_string(&tool_use.input)
150 .unwrap_or_default(),
151 input: tool_use.input,
152 is_input_complete: true,
153 thought_signature: None,
154 },
155 ));
156 }
157
158 // Convert tool results
159 let mut tool_results = IndexMap::default();
160 for tool_result in msg.tool_results {
161 let name = tool_names_by_id
162 .remove(&tool_result.tool_use_id)
163 .unwrap_or_else(|| SharedString::from("unknown"));
164 tool_results.insert(
165 tool_result.tool_use_id.clone(),
166 language_model::LanguageModelToolResult {
167 tool_use_id: tool_result.tool_use_id,
168 tool_name: name.into(),
169 is_error: tool_result.is_error,
170 content: tool_result.content,
171 output: tool_result.output,
172 },
173 );
174 }
175
176 if let Some(last_user_message_id) = &last_user_message_id
177 && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
178 {
179 request_token_usage.insert(last_user_message_id.clone(), token_usage);
180 }
181
182 crate::Message::Agent(AgentMessage {
183 content,
184 tool_results,
185 reasoning_details: None,
186 })
187 }
188 language_model::Role::System => {
189 // Skip system messages as they're not supported in the new format
190 continue;
191 }
192 };
193
194 messages.push(message);
195 }
196
197 Ok(Self {
198 title: thread.summary,
199 messages,
200 updated_at: thread.updated_at,
201 detailed_summary: match thread.detailed_summary_state {
202 crate::legacy_thread::DetailedSummaryState::NotGenerated
203 | crate::legacy_thread::DetailedSummaryState::Generating => None,
204 crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
205 },
206 initial_project_snapshot: thread.initial_project_snapshot,
207 cumulative_token_usage: thread.cumulative_token_usage,
208 request_token_usage,
209 model: thread.model,
210 completion_mode: thread.completion_mode,
211 profile: thread.profile,
212 })
213 }
214}
215
216#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
217pub enum DataType {
218 #[serde(rename = "json")]
219 Json,
220 #[serde(rename = "zstd")]
221 Zstd,
222}
223
224impl Bind for DataType {
225 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
226 let value = match self {
227 DataType::Json => "json",
228 DataType::Zstd => "zstd",
229 };
230 value.bind(statement, start_index)
231 }
232}
233
234impl Column for DataType {
235 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
236 let (value, next_index) = String::column(statement, start_index)?;
237 let data_type = match value.as_str() {
238 "json" => DataType::Json,
239 "zstd" => DataType::Zstd,
240 _ => anyhow::bail!("Unknown data type: {}", value),
241 };
242 Ok((data_type, next_index))
243 }
244}
245
/// Handle to the SQLite-backed store of threads.
///
/// Each database operation is spawned on `executor` and locks `connection`
/// for the duration of its work.
pub(crate) struct ThreadsDatabase {
    /// Executor on which database operations are spawned.
    executor: BackgroundExecutor,
    /// The single database connection, shared between spawned operations
    /// behind a mutex.
    connection: Arc<Mutex<Connection>>,
}
250
/// App-global holder for the shared task that opens the [`ThreadsDatabase`].
///
/// `ThreadsDatabase::connect` stores the task here on first use and hands out
/// clones of it afterwards. A failure is wrapped in `Arc` in the task's result type.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

// Marker impl that lets the wrapper be stored and looked up as a gpui global.
impl Global for GlobalThreadsDatabase {}
254
255impl ThreadsDatabase {
256 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
257 if cx.has_global::<GlobalThreadsDatabase>() {
258 return cx.global::<GlobalThreadsDatabase>().0.clone();
259 }
260 let executor = cx.background_executor().clone();
261 let task = executor
262 .spawn({
263 let executor = executor.clone();
264 async move {
265 match ThreadsDatabase::new(executor) {
266 Ok(db) => Ok(Arc::new(db)),
267 Err(err) => Err(Arc::new(err)),
268 }
269 }
270 })
271 .shared();
272
273 cx.set_global(GlobalThreadsDatabase(task.clone()));
274 task
275 }
276
277 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
278 let connection = if *ZED_STATELESS {
279 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
280 } else if cfg!(any(feature = "test-support", test)) {
281 // rust stores the name of the test on the current thread.
282 // We use this to automatically create a database that will
283 // be shared within the test (for the test_retrieve_old_thread)
284 // but not with concurrent tests.
285 let thread = std::thread::current();
286 let test_name = thread.name();
287 Connection::open_memory(Some(&format!(
288 "THREAD_FALLBACK_{}",
289 test_name.unwrap_or_default()
290 )))
291 } else {
292 let threads_dir = paths::data_dir().join("threads");
293 std::fs::create_dir_all(&threads_dir)?;
294 let sqlite_path = threads_dir.join("threads.db");
295 Connection::open_file(&sqlite_path.to_string_lossy())
296 };
297
298 connection.exec(indoc! {"
299 CREATE TABLE IF NOT EXISTS threads (
300 id TEXT PRIMARY KEY,
301 summary TEXT NOT NULL,
302 updated_at TEXT NOT NULL,
303 data_type TEXT NOT NULL,
304 data BLOB NOT NULL
305 )
306 "})?()
307 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
308
309 let db = Self {
310 executor,
311 connection: Arc::new(Mutex::new(connection)),
312 };
313
314 Ok(db)
315 }
316
317 fn save_thread_sync(
318 connection: &Arc<Mutex<Connection>>,
319 id: acp::SessionId,
320 thread: DbThread,
321 ) -> Result<()> {
322 const COMPRESSION_LEVEL: i32 = 3;
323
324 #[derive(Serialize)]
325 struct SerializedThread {
326 #[serde(flatten)]
327 thread: DbThread,
328 version: &'static str,
329 }
330
331 let title = thread.title.to_string();
332 let updated_at = thread.updated_at.to_rfc3339();
333 let json_data = serde_json::to_string(&SerializedThread {
334 thread,
335 version: DbThread::VERSION,
336 })?;
337
338 let connection = connection.lock();
339
340 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
341 let data_type = DataType::Zstd;
342 let data = compressed;
343
344 let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
345 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
346 "})?;
347
348 insert((id.0, title, updated_at, data_type, data))?;
349
350 Ok(())
351 }
352
353 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
354 let connection = self.connection.clone();
355
356 self.executor.spawn(async move {
357 let connection = connection.lock();
358
359 let mut select =
360 connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
361 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
362 "})?;
363
364 let rows = select(())?;
365 let mut threads = Vec::new();
366
367 for (id, summary, updated_at) in rows {
368 threads.push(DbThreadMetadata {
369 id: acp::SessionId(id),
370 title: summary.into(),
371 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
372 });
373 }
374
375 Ok(threads)
376 })
377 }
378
379 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
380 let connection = self.connection.clone();
381
382 self.executor.spawn(async move {
383 let connection = connection.lock();
384 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
385 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
386 "})?;
387
388 let rows = select(id.0)?;
389 if let Some((data_type, data)) = rows.into_iter().next() {
390 let json_data = match data_type {
391 DataType::Zstd => {
392 let decompressed = zstd::decode_all(&data[..])?;
393 String::from_utf8(decompressed)?
394 }
395 DataType::Json => String::from_utf8(data)?,
396 };
397 let thread = DbThread::from_json(json_data.as_bytes())?;
398 Ok(Some(thread))
399 } else {
400 Ok(None)
401 }
402 })
403 }
404
405 pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
406 let connection = self.connection.clone();
407
408 self.executor
409 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
410 }
411
412 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
413 let connection = self.connection.clone();
414
415 self.executor.spawn(async move {
416 let connection = connection.lock();
417
418 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
419 DELETE FROM threads WHERE id = ?
420 "})?;
421
422 delete(id.0)?;
423
424 Ok(())
425 })
426 }
427
428 pub fn delete_threads(&self) -> Task<Result<()>> {
429 let connection = self.connection.clone();
430
431 self.executor.spawn(async move {
432 let connection = connection.lock();
433
434 let mut delete = connection.exec_bound::<()>(indoc! {"
435 DELETE FROM threads
436 "})?;
437
438 delete(())?;
439
440 Ok(())
441 })
442 }
443}