db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::{AgentProfileId, CompletionMode};
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use parking_lot::Mutex;
 12use serde::{Deserialize, Serialize};
 13use sqlez::{
 14    bindable::{Bind, Column},
 15    connection::Connection,
 16    statement::Statement,
 17};
 18use std::sync::Arc;
 19use ui::{App, SharedString};
 20use zed_env_vars::ZED_STATELESS;
 21
 22pub type DbMessage = crate::Message;
 23pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
 24pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 25
 26#[derive(Debug, Clone, Serialize, Deserialize)]
 27pub struct DbThreadMetadata {
 28    pub id: acp::SessionId,
 29    #[serde(alias = "summary")]
 30    pub title: SharedString,
 31    pub updated_at: DateTime<Utc>,
 32}
 33
 34#[derive(Debug, Serialize, Deserialize)]
 35pub struct DbThread {
 36    pub title: SharedString,
 37    pub messages: Vec<DbMessage>,
 38    pub updated_at: DateTime<Utc>,
 39    #[serde(default)]
 40    pub detailed_summary: Option<SharedString>,
 41    #[serde(default)]
 42    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
 43    #[serde(default)]
 44    pub cumulative_token_usage: language_model::TokenUsage,
 45    #[serde(default)]
 46    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
 47    #[serde(default)]
 48    pub model: Option<DbLanguageModel>,
 49    #[serde(default)]
 50    pub completion_mode: Option<CompletionMode>,
 51    #[serde(default)]
 52    pub profile: Option<AgentProfileId>,
 53}
 54
 55impl DbThread {
 56    pub const VERSION: &'static str = "0.3.0";
 57
 58    pub fn from_json(json: &[u8]) -> Result<Self> {
 59        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 60        match saved_thread_json.get("version") {
 61            Some(serde_json::Value::String(version)) => match version.as_str() {
 62                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
 63                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
 64                    json,
 65                )?),
 66            },
 67            _ => {
 68                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
 69            }
 70        }
 71    }
 72
 73    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
 74        let mut messages = Vec::new();
 75        let mut request_token_usage = HashMap::default();
 76
 77        let mut last_user_message_id = None;
 78        for (ix, msg) in thread.messages.into_iter().enumerate() {
 79            let message = match msg.role {
 80                language_model::Role::User => {
 81                    let mut content = Vec::new();
 82
 83                    // Convert segments to content
 84                    for segment in msg.segments {
 85                        match segment {
 86                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
 87                                content.push(UserMessageContent::Text(text));
 88                            }
 89                            crate::legacy_thread::SerializedMessageSegment::Thinking {
 90                                text,
 91                                ..
 92                            } => {
 93                                // User messages don't have thinking segments, but handle gracefully
 94                                content.push(UserMessageContent::Text(text));
 95                            }
 96                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
 97                                ..
 98                            } => {
 99                                // User messages don't have redacted thinking, skip.
100                            }
101                        }
102                    }
103
104                    // If no content was added, add context as text if available
105                    if content.is_empty() && !msg.context.is_empty() {
106                        content.push(UserMessageContent::Text(msg.context));
107                    }
108
109                    let id = UserMessageId::new();
110                    last_user_message_id = Some(id.clone());
111
112                    crate::Message::User(UserMessage {
113                        // MessageId from old format can't be meaningfully converted, so generate a new one
114                        id,
115                        content,
116                    })
117                }
118                language_model::Role::Assistant => {
119                    let mut content = Vec::new();
120
121                    // Convert segments to content
122                    for segment in msg.segments {
123                        match segment {
124                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
125                                content.push(AgentMessageContent::Text(text));
126                            }
127                            crate::legacy_thread::SerializedMessageSegment::Thinking {
128                                text,
129                                signature,
130                            } => {
131                                content.push(AgentMessageContent::Thinking { text, signature });
132                            }
133                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
134                                data,
135                            } => {
136                                content.push(AgentMessageContent::RedactedThinking(data));
137                            }
138                        }
139                    }
140
141                    // Convert tool uses
142                    let mut tool_names_by_id = HashMap::default();
143                    for tool_use in msg.tool_uses {
144                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
145                        content.push(AgentMessageContent::ToolUse(
146                            language_model::LanguageModelToolUse {
147                                id: tool_use.id,
148                                name: tool_use.name.into(),
149                                raw_input: serde_json::to_string(&tool_use.input)
150                                    .unwrap_or_default(),
151                                input: tool_use.input,
152                                is_input_complete: true,
153                                thought_signature: None,
154                            },
155                        ));
156                    }
157
158                    // Convert tool results
159                    let mut tool_results = IndexMap::default();
160                    for tool_result in msg.tool_results {
161                        let name = tool_names_by_id
162                            .remove(&tool_result.tool_use_id)
163                            .unwrap_or_else(|| SharedString::from("unknown"));
164                        tool_results.insert(
165                            tool_result.tool_use_id.clone(),
166                            language_model::LanguageModelToolResult {
167                                tool_use_id: tool_result.tool_use_id,
168                                tool_name: name.into(),
169                                is_error: tool_result.is_error,
170                                content: tool_result.content,
171                                output: tool_result.output,
172                            },
173                        );
174                    }
175
176                    if let Some(last_user_message_id) = &last_user_message_id
177                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
178                    {
179                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
180                    }
181
182                    crate::Message::Agent(AgentMessage {
183                        content,
184                        tool_results,
185                        reasoning_details: None,
186                    })
187                }
188                language_model::Role::System => {
189                    // Skip system messages as they're not supported in the new format
190                    continue;
191                }
192            };
193
194            messages.push(message);
195        }
196
197        Ok(Self {
198            title: thread.summary,
199            messages,
200            updated_at: thread.updated_at,
201            detailed_summary: match thread.detailed_summary_state {
202                crate::legacy_thread::DetailedSummaryState::NotGenerated
203                | crate::legacy_thread::DetailedSummaryState::Generating => None,
204                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
205            },
206            initial_project_snapshot: thread.initial_project_snapshot,
207            cumulative_token_usage: thread.cumulative_token_usage,
208            request_token_usage,
209            model: thread.model,
210            completion_mode: thread.completion_mode,
211            profile: thread.profile,
212        })
213    }
214}
215
216#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
217pub enum DataType {
218    #[serde(rename = "json")]
219    Json,
220    #[serde(rename = "zstd")]
221    Zstd,
222}
223
224impl Bind for DataType {
225    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
226        let value = match self {
227            DataType::Json => "json",
228            DataType::Zstd => "zstd",
229        };
230        value.bind(statement, start_index)
231    }
232}
233
234impl Column for DataType {
235    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
236        let (value, next_index) = String::column(statement, start_index)?;
237        let data_type = match value.as_str() {
238            "json" => DataType::Json,
239            "zstd" => DataType::Zstd,
240            _ => anyhow::bail!("Unknown data type: {}", value),
241        };
242        Ok((data_type, next_index))
243    }
244}
245
246pub(crate) struct ThreadsDatabase {
247    executor: BackgroundExecutor,
248    connection: Arc<Mutex<Connection>>,
249}
250
251struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
252
253impl Global for GlobalThreadsDatabase {}
254
255impl ThreadsDatabase {
256    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
257        if cx.has_global::<GlobalThreadsDatabase>() {
258            return cx.global::<GlobalThreadsDatabase>().0.clone();
259        }
260        let executor = cx.background_executor().clone();
261        let task = executor
262            .spawn({
263                let executor = executor.clone();
264                async move {
265                    match ThreadsDatabase::new(executor) {
266                        Ok(db) => Ok(Arc::new(db)),
267                        Err(err) => Err(Arc::new(err)),
268                    }
269                }
270            })
271            .shared();
272
273        cx.set_global(GlobalThreadsDatabase(task.clone()));
274        task
275    }
276
277    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
278        let connection = if *ZED_STATELESS {
279            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
280        } else if cfg!(any(feature = "test-support", test)) {
281            // rust stores the name of the test on the current thread.
282            // We use this to automatically create a database that will
283            // be shared within the test (for the test_retrieve_old_thread)
284            // but not with concurrent tests.
285            let thread = std::thread::current();
286            let test_name = thread.name();
287            Connection::open_memory(Some(&format!(
288                "THREAD_FALLBACK_{}",
289                test_name.unwrap_or_default()
290            )))
291        } else {
292            let threads_dir = paths::data_dir().join("threads");
293            std::fs::create_dir_all(&threads_dir)?;
294            let sqlite_path = threads_dir.join("threads.db");
295            Connection::open_file(&sqlite_path.to_string_lossy())
296        };
297
298        connection.exec(indoc! {"
299            CREATE TABLE IF NOT EXISTS threads (
300                id TEXT PRIMARY KEY,
301                summary TEXT NOT NULL,
302                updated_at TEXT NOT NULL,
303                data_type TEXT NOT NULL,
304                data BLOB NOT NULL
305            )
306        "})?()
307        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
308
309        let db = Self {
310            executor,
311            connection: Arc::new(Mutex::new(connection)),
312        };
313
314        Ok(db)
315    }
316
317    fn save_thread_sync(
318        connection: &Arc<Mutex<Connection>>,
319        id: acp::SessionId,
320        thread: DbThread,
321    ) -> Result<()> {
322        const COMPRESSION_LEVEL: i32 = 3;
323
324        #[derive(Serialize)]
325        struct SerializedThread {
326            #[serde(flatten)]
327            thread: DbThread,
328            version: &'static str,
329        }
330
331        let title = thread.title.to_string();
332        let updated_at = thread.updated_at.to_rfc3339();
333        let json_data = serde_json::to_string(&SerializedThread {
334            thread,
335            version: DbThread::VERSION,
336        })?;
337
338        let connection = connection.lock();
339
340        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
341        let data_type = DataType::Zstd;
342        let data = compressed;
343
344        let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
345            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
346        "})?;
347
348        insert((id.0, title, updated_at, data_type, data))?;
349
350        Ok(())
351    }
352
353    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
354        let connection = self.connection.clone();
355
356        self.executor.spawn(async move {
357            let connection = connection.lock();
358
359            let mut select =
360                connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
361                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
362            "})?;
363
364            let rows = select(())?;
365            let mut threads = Vec::new();
366
367            for (id, summary, updated_at) in rows {
368                threads.push(DbThreadMetadata {
369                    id: acp::SessionId(id),
370                    title: summary.into(),
371                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
372                });
373            }
374
375            Ok(threads)
376        })
377    }
378
379    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
380        let connection = self.connection.clone();
381
382        self.executor.spawn(async move {
383            let connection = connection.lock();
384            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
385                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
386            "})?;
387
388            let rows = select(id.0)?;
389            if let Some((data_type, data)) = rows.into_iter().next() {
390                let json_data = match data_type {
391                    DataType::Zstd => {
392                        let decompressed = zstd::decode_all(&data[..])?;
393                        String::from_utf8(decompressed)?
394                    }
395                    DataType::Json => String::from_utf8(data)?,
396                };
397                let thread = DbThread::from_json(json_data.as_bytes())?;
398                Ok(Some(thread))
399            } else {
400                Ok(None)
401            }
402        })
403    }
404
405    pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
406        let connection = self.connection.clone();
407
408        self.executor
409            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
410    }
411
412    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
413        let connection = self.connection.clone();
414
415        self.executor.spawn(async move {
416            let connection = connection.lock();
417
418            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
419                DELETE FROM threads WHERE id = ?
420            "})?;
421
422            delete(id.0)?;
423
424            Ok(())
425        })
426    }
427
428    pub fn delete_threads(&self) -> Task<Result<()>> {
429        let connection = self.connection.clone();
430
431        self.executor.spawn(async move {
432            let connection = connection.lock();
433
434            let mut delete = connection.exec_bound::<()>(indoc! {"
435                DELETE FROM threads
436            "})?;
437
438            delete(())?;
439
440            Ok(())
441        })
442    }
443}