1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent_client_protocol as acp;
4use agent_settings::{AgentProfileId, CompletionMode};
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use sqlez::{
14 bindable::{Bind, Column},
15 connection::Connection,
16 statement::Statement,
17};
18use std::sync::Arc;
19use ui::{App, SharedString};
20use zed_env_vars::ZED_STATELESS;
21
/// A single message of a persisted thread; alias for [`crate::Message`].
pub type DbMessage = crate::Message;
/// Alias for the legacy thread format's [`crate::legacy_thread::DetailedSummaryState`].
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Alias for the legacy thread format's [`crate::legacy_thread::SerializedLanguageModel`];
/// this is the type stored in [`DbThread::model`].
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
25
/// Lightweight listing entry for a saved thread.
///
/// `ThreadsDatabase::list_threads` builds one of these per row of the
/// `threads` table without loading or decoding the thread body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session identifier, read from the `id` column.
    pub id: acp::SessionId,
    /// Thread title, read from the `summary` column. When deserializing, the
    /// field name `summary` is accepted as an alias.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Value of the `updated_at` column, parsed from RFC 3339 and converted to UTC.
    pub updated_at: DateTime<Utc>,
}
33
/// The full persisted form of a thread.
///
/// Every field marked `#[serde(default)]` may be absent from the stored JSON,
/// in which case it deserializes to its `Default` value.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; on save it is also written to the `summary` column.
    pub title: SharedString,
    /// The thread's messages, in order.
    pub messages: Vec<DbMessage>,
    /// Timestamp written (as RFC 3339) to the `updated_at` column on save.
    pub updated_at: DateTime<Utc>,
    /// Detailed summary text, if any. When upgrading a legacy thread this is
    /// only populated from the `Generated` summary state.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Optional project snapshot carried with the thread.
    // NOTE(review): presumably captured when the thread was started — confirm.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    /// Token usage for the thread as a whole.
    // NOTE(review): presumably a running total across requests — confirm.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage keyed by user message id. When upgrading a legacy thread,
    /// the usage recorded for an assistant message is attributed to the most
    /// recent preceding user message.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Language model recorded for the thread, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Completion mode recorded for the thread, if any.
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    /// Agent profile recorded for the thread, if any.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
54
55impl DbThread {
56 pub const VERSION: &'static str = "0.3.0";
57
58 pub fn from_json(json: &[u8]) -> Result<Self> {
59 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
60 match saved_thread_json.get("version") {
61 Some(serde_json::Value::String(version)) => match version.as_str() {
62 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
63 _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
64 json,
65 )?),
66 },
67 _ => {
68 Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
69 }
70 }
71 }
72
73 fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
74 let mut messages = Vec::new();
75 let mut request_token_usage = HashMap::default();
76
77 let mut last_user_message_id = None;
78 for (ix, msg) in thread.messages.into_iter().enumerate() {
79 let message = match msg.role {
80 language_model::Role::User => {
81 let mut content = Vec::new();
82
83 // Convert segments to content
84 for segment in msg.segments {
85 match segment {
86 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
87 content.push(UserMessageContent::Text(text));
88 }
89 crate::legacy_thread::SerializedMessageSegment::Thinking {
90 text,
91 ..
92 } => {
93 // User messages don't have thinking segments, but handle gracefully
94 content.push(UserMessageContent::Text(text));
95 }
96 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
97 ..
98 } => {
99 // User messages don't have redacted thinking, skip.
100 }
101 }
102 }
103
104 // If no content was added, add context as text if available
105 if content.is_empty() && !msg.context.is_empty() {
106 content.push(UserMessageContent::Text(msg.context));
107 }
108
109 let id = UserMessageId::new();
110 last_user_message_id = Some(id.clone());
111
112 crate::Message::User(UserMessage {
113 // MessageId from old format can't be meaningfully converted, so generate a new one
114 id,
115 content,
116 })
117 }
118 language_model::Role::Assistant => {
119 let mut content = Vec::new();
120
121 // Convert segments to content
122 for segment in msg.segments {
123 match segment {
124 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
125 content.push(AgentMessageContent::Text(text));
126 }
127 crate::legacy_thread::SerializedMessageSegment::Thinking {
128 text,
129 signature,
130 } => {
131 content.push(AgentMessageContent::Thinking { text, signature });
132 }
133 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
134 data,
135 } => {
136 content.push(AgentMessageContent::RedactedThinking(data));
137 }
138 }
139 }
140
141 // Convert tool uses
142 let mut tool_names_by_id = HashMap::default();
143 for tool_use in msg.tool_uses {
144 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
145 content.push(AgentMessageContent::ToolUse(
146 language_model::LanguageModelToolUse {
147 id: tool_use.id,
148 name: tool_use.name.into(),
149 raw_input: serde_json::to_string(&tool_use.input)
150 .unwrap_or_default(),
151 input: tool_use.input,
152 is_input_complete: true,
153 },
154 ));
155 }
156
157 // Convert tool results
158 let mut tool_results = IndexMap::default();
159 for tool_result in msg.tool_results {
160 let name = tool_names_by_id
161 .remove(&tool_result.tool_use_id)
162 .unwrap_or_else(|| SharedString::from("unknown"));
163 tool_results.insert(
164 tool_result.tool_use_id.clone(),
165 language_model::LanguageModelToolResult {
166 tool_use_id: tool_result.tool_use_id,
167 tool_name: name.into(),
168 is_error: tool_result.is_error,
169 content: tool_result.content,
170 output: tool_result.output,
171 },
172 );
173 }
174
175 if let Some(last_user_message_id) = &last_user_message_id
176 && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
177 {
178 request_token_usage.insert(last_user_message_id.clone(), token_usage);
179 }
180
181 crate::Message::Agent(AgentMessage {
182 content,
183 tool_results,
184 })
185 }
186 language_model::Role::System => {
187 // Skip system messages as they're not supported in the new format
188 continue;
189 }
190 };
191
192 messages.push(message);
193 }
194
195 Ok(Self {
196 title: thread.summary,
197 messages,
198 updated_at: thread.updated_at,
199 detailed_summary: match thread.detailed_summary_state {
200 crate::legacy_thread::DetailedSummaryState::NotGenerated
201 | crate::legacy_thread::DetailedSummaryState::Generating => None,
202 crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
203 },
204 initial_project_snapshot: thread.initial_project_snapshot,
205 cumulative_token_usage: thread.cumulative_token_usage,
206 request_token_usage,
207 model: thread.model,
208 completion_mode: thread.completion_mode,
209 profile: thread.profile,
210 })
211 }
212}
213
/// Encoding of the `data` blob in a `threads` row, stored in the `data_type`
/// column as the text `"json"` or `"zstd"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    /// The blob is UTF-8 JSON.
    #[serde(rename = "json")]
    Json,
    /// The blob is zstd-compressed JSON.
    #[serde(rename = "zstd")]
    Zstd,
}
221
222impl Bind for DataType {
223 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
224 let value = match self {
225 DataType::Json => "json",
226 DataType::Zstd => "zstd",
227 };
228 value.bind(statement, start_index)
229 }
230}
231
232impl Column for DataType {
233 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
234 let (value, next_index) = String::column(statement, start_index)?;
235 let data_type = match value.as_str() {
236 "json" => DataType::Json,
237 "zstd" => DataType::Zstd,
238 _ => anyhow::bail!("Unknown data type: {}", value),
239 };
240 Ok((data_type, next_index))
241 }
242}
243
/// Handle to the thread store: a single connection guarded by a mutex, plus
/// the background executor on which every query is spawned.
pub(crate) struct ThreadsDatabase {
    /// Executor used to run database work off the calling thread.
    executor: BackgroundExecutor,
    /// The one connection shared by all operations; each operation locks it.
    connection: Arc<Mutex<Connection>>,
}
248
/// App-global wrapper around the shared task that opens the database, so that
/// `ThreadsDatabase::connect` can hand the same task to every caller.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
252
253impl ThreadsDatabase {
254 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
255 if cx.has_global::<GlobalThreadsDatabase>() {
256 return cx.global::<GlobalThreadsDatabase>().0.clone();
257 }
258 let executor = cx.background_executor().clone();
259 let task = executor
260 .spawn({
261 let executor = executor.clone();
262 async move {
263 match ThreadsDatabase::new(executor) {
264 Ok(db) => Ok(Arc::new(db)),
265 Err(err) => Err(Arc::new(err)),
266 }
267 }
268 })
269 .shared();
270
271 cx.set_global(GlobalThreadsDatabase(task.clone()));
272 task
273 }
274
275 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
276 let connection = if *ZED_STATELESS {
277 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
278 } else if cfg!(any(feature = "test-support", test)) {
279 // rust stores the name of the test on the current thread.
280 // We use this to automatically create a database that will
281 // be shared within the test (for the test_retrieve_old_thread)
282 // but not with concurrent tests.
283 let thread = std::thread::current();
284 let test_name = thread.name();
285 Connection::open_memory(Some(&format!(
286 "THREAD_FALLBACK_{}",
287 test_name.unwrap_or_default()
288 )))
289 } else {
290 let threads_dir = paths::data_dir().join("threads");
291 std::fs::create_dir_all(&threads_dir)?;
292 let sqlite_path = threads_dir.join("threads.db");
293 Connection::open_file(&sqlite_path.to_string_lossy())
294 };
295
296 connection.exec(indoc! {"
297 CREATE TABLE IF NOT EXISTS threads (
298 id TEXT PRIMARY KEY,
299 summary TEXT NOT NULL,
300 updated_at TEXT NOT NULL,
301 data_type TEXT NOT NULL,
302 data BLOB NOT NULL
303 )
304 "})?()
305 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
306
307 let db = Self {
308 executor,
309 connection: Arc::new(Mutex::new(connection)),
310 };
311
312 Ok(db)
313 }
314
315 fn save_thread_sync(
316 connection: &Arc<Mutex<Connection>>,
317 id: acp::SessionId,
318 thread: DbThread,
319 ) -> Result<()> {
320 const COMPRESSION_LEVEL: i32 = 3;
321
322 #[derive(Serialize)]
323 struct SerializedThread {
324 #[serde(flatten)]
325 thread: DbThread,
326 version: &'static str,
327 }
328
329 let title = thread.title.to_string();
330 let updated_at = thread.updated_at.to_rfc3339();
331 let json_data = serde_json::to_string(&SerializedThread {
332 thread,
333 version: DbThread::VERSION,
334 })?;
335
336 let connection = connection.lock();
337
338 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
339 let data_type = DataType::Zstd;
340 let data = compressed;
341
342 let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
343 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
344 "})?;
345
346 insert((id.0, title, updated_at, data_type, data))?;
347
348 Ok(())
349 }
350
351 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
352 let connection = self.connection.clone();
353
354 self.executor.spawn(async move {
355 let connection = connection.lock();
356
357 let mut select =
358 connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
359 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
360 "})?;
361
362 let rows = select(())?;
363 let mut threads = Vec::new();
364
365 for (id, summary, updated_at) in rows {
366 threads.push(DbThreadMetadata {
367 id: acp::SessionId(id),
368 title: summary.into(),
369 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
370 });
371 }
372
373 Ok(threads)
374 })
375 }
376
377 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
378 let connection = self.connection.clone();
379
380 self.executor.spawn(async move {
381 let connection = connection.lock();
382 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
383 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
384 "})?;
385
386 let rows = select(id.0)?;
387 if let Some((data_type, data)) = rows.into_iter().next() {
388 let json_data = match data_type {
389 DataType::Zstd => {
390 let decompressed = zstd::decode_all(&data[..])?;
391 String::from_utf8(decompressed)?
392 }
393 DataType::Json => String::from_utf8(data)?,
394 };
395 let thread = DbThread::from_json(json_data.as_bytes())?;
396 Ok(Some(thread))
397 } else {
398 Ok(None)
399 }
400 })
401 }
402
403 pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
404 let connection = self.connection.clone();
405
406 self.executor
407 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
408 }
409
410 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
411 let connection = self.connection.clone();
412
413 self.executor.spawn(async move {
414 let connection = connection.lock();
415
416 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
417 DELETE FROM threads WHERE id = ?
418 "})?;
419
420 delete(id.0)?;
421
422 Ok(())
423 })
424 }
425}