1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
2use acp_thread::UserMessageId;
3use agent_client_protocol as acp;
4use agent_settings::{AgentProfileId, CompletionMode};
5use anyhow::{Result, anyhow};
6use chrono::{DateTime, Utc};
7use collections::{HashMap, IndexMap};
8use futures::{FutureExt, future::Shared};
9use gpui::{BackgroundExecutor, Global, Task};
10use indoc::indoc;
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use sqlez::{
14 bindable::{Bind, Column},
15 connection::Connection,
16 statement::Statement,
17};
18use std::sync::Arc;
19use ui::{App, SharedString};
20use zed_env_vars::ZED_STATELESS;
21
/// Message type stored inside a persisted thread; alias for `crate::Message`.
pub type DbMessage = crate::Message;
/// Alias for the legacy thread format's detailed-summary state type.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Alias for the legacy thread format's serialized language-model type.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
25
/// Lightweight description of a stored thread: its id, title and last-update time,
/// without the message payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session id identifying the thread.
    pub id: acp::SessionId,
    /// Thread title. When deserializing, a field named `summary` is accepted as well.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Time of the thread's last update, in UTC.
    pub updated_at: DateTime<Utc>,
}
33
/// Full persisted form of a thread: title, messages and associated bookkeeping.
///
/// Every field marked `#[serde(default)]` may be absent from the stored JSON, in which
/// case it deserializes to its type's default value.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title.
    pub title: SharedString,
    /// Messages of the thread, in order.
    pub messages: Vec<DbMessage>,
    /// Time of the thread's last update, in UTC.
    pub updated_at: DateTime<Utc>,
    /// Detailed summary text, if one exists.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Project snapshot associated with the thread, if one was recorded.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    /// Token usage accumulated over the whole thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage keyed by user message id.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Language model recorded for the thread, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Completion mode recorded for the thread, if any.
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    /// Agent profile recorded for the thread, if any.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
54
55impl DbThread {
56 pub const VERSION: &'static str = "0.3.0";
57
58 pub fn from_json(json: &[u8]) -> Result<Self> {
59 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
60 match saved_thread_json.get("version") {
61 Some(serde_json::Value::String(version)) => match version.as_str() {
62 Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
63 _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
64 json,
65 )?),
66 },
67 _ => {
68 Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
69 }
70 }
71 }
72
73 fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
74 let mut messages = Vec::new();
75 let mut request_token_usage = HashMap::default();
76
77 let mut last_user_message_id = None;
78 for (ix, msg) in thread.messages.into_iter().enumerate() {
79 let message = match msg.role {
80 language_model::Role::User => {
81 let mut content = Vec::new();
82
83 // Convert segments to content
84 for segment in msg.segments {
85 match segment {
86 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
87 content.push(UserMessageContent::Text(text));
88 }
89 crate::legacy_thread::SerializedMessageSegment::Thinking {
90 text,
91 ..
92 } => {
93 // User messages don't have thinking segments, but handle gracefully
94 content.push(UserMessageContent::Text(text));
95 }
96 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
97 ..
98 } => {
99 // User messages don't have redacted thinking, skip.
100 }
101 }
102 }
103
104 // If no content was added, add context as text if available
105 if content.is_empty() && !msg.context.is_empty() {
106 content.push(UserMessageContent::Text(msg.context));
107 }
108
109 let id = UserMessageId::new();
110 last_user_message_id = Some(id.clone());
111
112 crate::Message::User(UserMessage {
113 // MessageId from old format can't be meaningfully converted, so generate a new one
114 id,
115 content,
116 })
117 }
118 language_model::Role::Assistant => {
119 let mut content = Vec::new();
120
121 // Convert segments to content
122 for segment in msg.segments {
123 match segment {
124 crate::legacy_thread::SerializedMessageSegment::Text { text } => {
125 content.push(AgentMessageContent::Text(text));
126 }
127 crate::legacy_thread::SerializedMessageSegment::Thinking {
128 text,
129 signature,
130 } => {
131 content.push(AgentMessageContent::Thinking { text, signature });
132 }
133 crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
134 data,
135 } => {
136 content.push(AgentMessageContent::RedactedThinking(data));
137 }
138 }
139 }
140
141 // Convert tool uses
142 let mut tool_names_by_id = HashMap::default();
143 for tool_use in msg.tool_uses {
144 tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
145 content.push(AgentMessageContent::ToolUse(
146 language_model::LanguageModelToolUse {
147 id: tool_use.id,
148 name: tool_use.name.into(),
149 raw_input: serde_json::to_string(&tool_use.input)
150 .unwrap_or_default(),
151 input: tool_use.input,
152 is_input_complete: true,
153 thought_signature: None,
154 },
155 ));
156 }
157
158 // Convert tool results
159 let mut tool_results = IndexMap::default();
160 for tool_result in msg.tool_results {
161 let name = tool_names_by_id
162 .remove(&tool_result.tool_use_id)
163 .unwrap_or_else(|| SharedString::from("unknown"));
164 tool_results.insert(
165 tool_result.tool_use_id.clone(),
166 language_model::LanguageModelToolResult {
167 tool_use_id: tool_result.tool_use_id,
168 tool_name: name.into(),
169 is_error: tool_result.is_error,
170 content: tool_result.content,
171 output: tool_result.output,
172 },
173 );
174 }
175
176 if let Some(last_user_message_id) = &last_user_message_id
177 && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
178 {
179 request_token_usage.insert(last_user_message_id.clone(), token_usage);
180 }
181
182 crate::Message::Agent(AgentMessage {
183 content,
184 tool_results,
185 })
186 }
187 language_model::Role::System => {
188 // Skip system messages as they're not supported in the new format
189 continue;
190 }
191 };
192
193 messages.push(message);
194 }
195
196 Ok(Self {
197 title: thread.summary,
198 messages,
199 updated_at: thread.updated_at,
200 detailed_summary: match thread.detailed_summary_state {
201 crate::legacy_thread::DetailedSummaryState::NotGenerated
202 | crate::legacy_thread::DetailedSummaryState::Generating => None,
203 crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
204 },
205 initial_project_snapshot: thread.initial_project_snapshot,
206 cumulative_token_usage: thread.cumulative_token_usage,
207 request_token_usage,
208 model: thread.model,
209 completion_mode: thread.completion_mode,
210 profile: thread.profile,
211 })
212 }
213}
214
/// Encoding of the `data` blob stored for a thread row.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    /// The blob is JSON text; stored and serialized as the string `"json"`.
    #[serde(rename = "json")]
    Json,
    /// The blob is zstd-compressed JSON; stored and serialized as the string `"zstd"`.
    #[serde(rename = "zstd")]
    Zstd,
}
222
223impl Bind for DataType {
224 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
225 let value = match self {
226 DataType::Json => "json",
227 DataType::Zstd => "zstd",
228 };
229 value.bind(statement, start_index)
230 }
231}
232
233impl Column for DataType {
234 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
235 let (value, next_index) = String::column(statement, start_index)?;
236 let data_type = match value.as_str() {
237 "json" => DataType::Json,
238 "zstd" => DataType::Zstd,
239 _ => anyhow::bail!("Unknown data type: {}", value),
240 };
241 Ok((data_type, next_index))
242 }
243}
244
/// SQLite-backed store of agent threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which the database operations are spawned.
    executor: BackgroundExecutor,
    /// The database connection, shared with spawned tasks and guarded by a mutex.
    connection: Arc<Mutex<Connection>>,
}
249
/// App-global wrapper around the shared task that opens the threads database, so the
/// same task (and its result) can be handed out to every caller.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
253
254impl ThreadsDatabase {
255 pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
256 if cx.has_global::<GlobalThreadsDatabase>() {
257 return cx.global::<GlobalThreadsDatabase>().0.clone();
258 }
259 let executor = cx.background_executor().clone();
260 let task = executor
261 .spawn({
262 let executor = executor.clone();
263 async move {
264 match ThreadsDatabase::new(executor) {
265 Ok(db) => Ok(Arc::new(db)),
266 Err(err) => Err(Arc::new(err)),
267 }
268 }
269 })
270 .shared();
271
272 cx.set_global(GlobalThreadsDatabase(task.clone()));
273 task
274 }
275
276 pub fn new(executor: BackgroundExecutor) -> Result<Self> {
277 let connection = if *ZED_STATELESS {
278 Connection::open_memory(Some("THREAD_FALLBACK_DB"))
279 } else if cfg!(any(feature = "test-support", test)) {
280 // rust stores the name of the test on the current thread.
281 // We use this to automatically create a database that will
282 // be shared within the test (for the test_retrieve_old_thread)
283 // but not with concurrent tests.
284 let thread = std::thread::current();
285 let test_name = thread.name();
286 Connection::open_memory(Some(&format!(
287 "THREAD_FALLBACK_{}",
288 test_name.unwrap_or_default()
289 )))
290 } else {
291 let threads_dir = paths::data_dir().join("threads");
292 std::fs::create_dir_all(&threads_dir)?;
293 let sqlite_path = threads_dir.join("threads.db");
294 Connection::open_file(&sqlite_path.to_string_lossy())
295 };
296
297 connection.exec(indoc! {"
298 CREATE TABLE IF NOT EXISTS threads (
299 id TEXT PRIMARY KEY,
300 summary TEXT NOT NULL,
301 updated_at TEXT NOT NULL,
302 data_type TEXT NOT NULL,
303 data BLOB NOT NULL
304 )
305 "})?()
306 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
307
308 let db = Self {
309 executor,
310 connection: Arc::new(Mutex::new(connection)),
311 };
312
313 Ok(db)
314 }
315
316 fn save_thread_sync(
317 connection: &Arc<Mutex<Connection>>,
318 id: acp::SessionId,
319 thread: DbThread,
320 ) -> Result<()> {
321 const COMPRESSION_LEVEL: i32 = 3;
322
323 #[derive(Serialize)]
324 struct SerializedThread {
325 #[serde(flatten)]
326 thread: DbThread,
327 version: &'static str,
328 }
329
330 let title = thread.title.to_string();
331 let updated_at = thread.updated_at.to_rfc3339();
332 let json_data = serde_json::to_string(&SerializedThread {
333 thread,
334 version: DbThread::VERSION,
335 })?;
336
337 let connection = connection.lock();
338
339 let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
340 let data_type = DataType::Zstd;
341 let data = compressed;
342
343 let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
344 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
345 "})?;
346
347 insert((id.0, title, updated_at, data_type, data))?;
348
349 Ok(())
350 }
351
352 pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
353 let connection = self.connection.clone();
354
355 self.executor.spawn(async move {
356 let connection = connection.lock();
357
358 let mut select =
359 connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
360 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
361 "})?;
362
363 let rows = select(())?;
364 let mut threads = Vec::new();
365
366 for (id, summary, updated_at) in rows {
367 threads.push(DbThreadMetadata {
368 id: acp::SessionId(id),
369 title: summary.into(),
370 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
371 });
372 }
373
374 Ok(threads)
375 })
376 }
377
378 pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
379 let connection = self.connection.clone();
380
381 self.executor.spawn(async move {
382 let connection = connection.lock();
383 let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
384 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
385 "})?;
386
387 let rows = select(id.0)?;
388 if let Some((data_type, data)) = rows.into_iter().next() {
389 let json_data = match data_type {
390 DataType::Zstd => {
391 let decompressed = zstd::decode_all(&data[..])?;
392 String::from_utf8(decompressed)?
393 }
394 DataType::Json => String::from_utf8(data)?,
395 };
396 let thread = DbThread::from_json(json_data.as_bytes())?;
397 Ok(Some(thread))
398 } else {
399 Ok(None)
400 }
401 })
402 }
403
404 pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
405 let connection = self.connection.clone();
406
407 self.executor
408 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
409 }
410
411 pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
412 let connection = self.connection.clone();
413
414 self.executor.spawn(async move {
415 let connection = connection.lock();
416
417 let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
418 DELETE FROM threads WHERE id = ?
419 "})?;
420
421 delete(id.0)?;
422
423 Ok(())
424 })
425 }
426}