db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent::{thread::DetailedSummaryState, thread_store};
  4use agent_client_protocol as acp;
  5use agent_settings::{AgentProfileId, CompletionMode};
  6use anyhow::{Result, anyhow};
  7use chrono::{DateTime, Utc};
  8use collections::{HashMap, IndexMap};
  9use futures::{FutureExt, future::Shared};
 10use gpui::{BackgroundExecutor, Global, Task};
 11use indoc::indoc;
 12use parking_lot::Mutex;
 13use serde::{Deserialize, Serialize};
 14use sqlez::{
 15    bindable::{Bind, Column},
 16    connection::Connection,
 17    statement::Statement,
 18};
 19use std::sync::Arc;
 20use ui::{App, SharedString};
 21use zed_env_vars::ZED_STATELESS;
 22
 23pub type DbMessage = crate::Message;
 24pub type DbSummary = DetailedSummaryState;
 25pub type DbLanguageModel = thread_store::SerializedLanguageModel;
 26
 27#[derive(Debug, Clone, Serialize, Deserialize)]
 28pub struct DbThreadMetadata {
 29    pub id: acp::SessionId,
 30    #[serde(alias = "summary")]
 31    pub title: SharedString,
 32    pub updated_at: DateTime<Utc>,
 33}
 34
 35#[derive(Debug, Serialize, Deserialize)]
 36pub struct DbThread {
 37    pub title: SharedString,
 38    pub messages: Vec<DbMessage>,
 39    pub updated_at: DateTime<Utc>,
 40    #[serde(default)]
 41    pub detailed_summary: Option<SharedString>,
 42    #[serde(default)]
 43    pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
 44    #[serde(default)]
 45    pub cumulative_token_usage: language_model::TokenUsage,
 46    #[serde(default)]
 47    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
 48    #[serde(default)]
 49    pub model: Option<DbLanguageModel>,
 50    #[serde(default)]
 51    pub completion_mode: Option<CompletionMode>,
 52    #[serde(default)]
 53    pub profile: Option<AgentProfileId>,
 54}
 55
 56impl DbThread {
 57    pub const VERSION: &'static str = "0.3.0";
 58
 59    pub fn from_json(json: &[u8]) -> Result<Self> {
 60        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 61        match saved_thread_json.get("version") {
 62            Some(serde_json::Value::String(version)) => match version.as_str() {
 63                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
 64                _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
 65            },
 66            _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
 67        }
 68    }
 69
 70    fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
 71        let mut messages = Vec::new();
 72        let mut request_token_usage = HashMap::default();
 73
 74        let mut last_user_message_id = None;
 75        for (ix, msg) in thread.messages.into_iter().enumerate() {
 76            let message = match msg.role {
 77                language_model::Role::User => {
 78                    let mut content = Vec::new();
 79
 80                    // Convert segments to content
 81                    for segment in msg.segments {
 82                        match segment {
 83                            thread_store::SerializedMessageSegment::Text { text } => {
 84                                content.push(UserMessageContent::Text(text));
 85                            }
 86                            thread_store::SerializedMessageSegment::Thinking { text, .. } => {
 87                                // User messages don't have thinking segments, but handle gracefully
 88                                content.push(UserMessageContent::Text(text));
 89                            }
 90                            thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
 91                                // User messages don't have redacted thinking, skip.
 92                            }
 93                        }
 94                    }
 95
 96                    // If no content was added, add context as text if available
 97                    if content.is_empty() && !msg.context.is_empty() {
 98                        content.push(UserMessageContent::Text(msg.context));
 99                    }
100
101                    let id = UserMessageId::new();
102                    last_user_message_id = Some(id.clone());
103
104                    crate::Message::User(UserMessage {
105                        // MessageId from old format can't be meaningfully converted, so generate a new one
106                        id,
107                        content,
108                    })
109                }
110                language_model::Role::Assistant => {
111                    let mut content = Vec::new();
112
113                    // Convert segments to content
114                    for segment in msg.segments {
115                        match segment {
116                            thread_store::SerializedMessageSegment::Text { text } => {
117                                content.push(AgentMessageContent::Text(text));
118                            }
119                            thread_store::SerializedMessageSegment::Thinking {
120                                text,
121                                signature,
122                            } => {
123                                content.push(AgentMessageContent::Thinking { text, signature });
124                            }
125                            thread_store::SerializedMessageSegment::RedactedThinking { data } => {
126                                content.push(AgentMessageContent::RedactedThinking(data));
127                            }
128                        }
129                    }
130
131                    // Convert tool uses
132                    let mut tool_names_by_id = HashMap::default();
133                    for tool_use in msg.tool_uses {
134                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
135                        content.push(AgentMessageContent::ToolUse(
136                            language_model::LanguageModelToolUse {
137                                id: tool_use.id,
138                                name: tool_use.name.into(),
139                                raw_input: serde_json::to_string(&tool_use.input)
140                                    .unwrap_or_default(),
141                                input: tool_use.input,
142                                is_input_complete: true,
143                            },
144                        ));
145                    }
146
147                    // Convert tool results
148                    let mut tool_results = IndexMap::default();
149                    for tool_result in msg.tool_results {
150                        let name = tool_names_by_id
151                            .remove(&tool_result.tool_use_id)
152                            .unwrap_or_else(|| SharedString::from("unknown"));
153                        tool_results.insert(
154                            tool_result.tool_use_id.clone(),
155                            language_model::LanguageModelToolResult {
156                                tool_use_id: tool_result.tool_use_id,
157                                tool_name: name.into(),
158                                is_error: tool_result.is_error,
159                                content: tool_result.content,
160                                output: tool_result.output,
161                            },
162                        );
163                    }
164
165                    if let Some(last_user_message_id) = &last_user_message_id
166                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
167                    {
168                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
169                    }
170
171                    crate::Message::Agent(AgentMessage {
172                        content,
173                        tool_results,
174                    })
175                }
176                language_model::Role::System => {
177                    // Skip system messages as they're not supported in the new format
178                    continue;
179                }
180            };
181
182            messages.push(message);
183        }
184
185        Ok(Self {
186            title: thread.summary,
187            messages,
188            updated_at: thread.updated_at,
189            detailed_summary: match thread.detailed_summary_state {
190                DetailedSummaryState::NotGenerated | DetailedSummaryState::Generating { .. } => {
191                    None
192                }
193                DetailedSummaryState::Generated { text, .. } => Some(text),
194            },
195            initial_project_snapshot: thread.initial_project_snapshot,
196            cumulative_token_usage: thread.cumulative_token_usage,
197            request_token_usage,
198            model: thread.model,
199            completion_mode: thread.completion_mode,
200            profile: thread.profile,
201        })
202    }
203}
204
205#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
206pub enum DataType {
207    #[serde(rename = "json")]
208    Json,
209    #[serde(rename = "zstd")]
210    Zstd,
211}
212
213impl Bind for DataType {
214    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
215        let value = match self {
216            DataType::Json => "json",
217            DataType::Zstd => "zstd",
218        };
219        value.bind(statement, start_index)
220    }
221}
222
223impl Column for DataType {
224    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
225        let (value, next_index) = String::column(statement, start_index)?;
226        let data_type = match value.as_str() {
227            "json" => DataType::Json,
228            "zstd" => DataType::Zstd,
229            _ => anyhow::bail!("Unknown data type: {}", value),
230        };
231        Ok((data_type, next_index))
232    }
233}
234
235pub(crate) struct ThreadsDatabase {
236    executor: BackgroundExecutor,
237    connection: Arc<Mutex<Connection>>,
238}
239
240struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
241
242impl Global for GlobalThreadsDatabase {}
243
244impl ThreadsDatabase {
245    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
246        if cx.has_global::<GlobalThreadsDatabase>() {
247            return cx.global::<GlobalThreadsDatabase>().0.clone();
248        }
249        let executor = cx.background_executor().clone();
250        let task = executor
251            .spawn({
252                let executor = executor.clone();
253                async move {
254                    match ThreadsDatabase::new(executor) {
255                        Ok(db) => Ok(Arc::new(db)),
256                        Err(err) => Err(Arc::new(err)),
257                    }
258                }
259            })
260            .shared();
261
262        cx.set_global(GlobalThreadsDatabase(task.clone()));
263        task
264    }
265
266    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
267        let connection = if *ZED_STATELESS {
268            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
269        } else if cfg!(any(feature = "test-support", test)) {
270            // rust stores the name of the test on the current thread.
271            // We use this to automatically create a database that will
272            // be shared within the test (for the test_retrieve_old_thread)
273            // but not with concurrent tests.
274            let thread = std::thread::current();
275            let test_name = thread.name();
276            Connection::open_memory(Some(&format!(
277                "THREAD_FALLBACK_{}",
278                test_name.unwrap_or_default()
279            )))
280        } else {
281            let threads_dir = paths::data_dir().join("threads");
282            std::fs::create_dir_all(&threads_dir)?;
283            let sqlite_path = threads_dir.join("threads.db");
284            Connection::open_file(&sqlite_path.to_string_lossy())
285        };
286
287        connection.exec(indoc! {"
288            CREATE TABLE IF NOT EXISTS threads (
289                id TEXT PRIMARY KEY,
290                summary TEXT NOT NULL,
291                updated_at TEXT NOT NULL,
292                data_type TEXT NOT NULL,
293                data BLOB NOT NULL
294            )
295        "})?()
296        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
297
298        let db = Self {
299            executor,
300            connection: Arc::new(Mutex::new(connection)),
301        };
302
303        Ok(db)
304    }
305
306    fn save_thread_sync(
307        connection: &Arc<Mutex<Connection>>,
308        id: acp::SessionId,
309        thread: DbThread,
310    ) -> Result<()> {
311        const COMPRESSION_LEVEL: i32 = 3;
312
313        #[derive(Serialize)]
314        struct SerializedThread {
315            #[serde(flatten)]
316            thread: DbThread,
317            version: &'static str,
318        }
319
320        let title = thread.title.to_string();
321        let updated_at = thread.updated_at.to_rfc3339();
322        let json_data = serde_json::to_string(&SerializedThread {
323            thread,
324            version: DbThread::VERSION,
325        })?;
326
327        let connection = connection.lock();
328
329        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
330        let data_type = DataType::Zstd;
331        let data = compressed;
332
333        let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
334            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
335        "})?;
336
337        insert((id.0, title, updated_at, data_type, data))?;
338
339        Ok(())
340    }
341
342    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
343        let connection = self.connection.clone();
344
345        self.executor.spawn(async move {
346            let connection = connection.lock();
347
348            let mut select =
349                connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
350                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
351            "})?;
352
353            let rows = select(())?;
354            let mut threads = Vec::new();
355
356            for (id, summary, updated_at) in rows {
357                threads.push(DbThreadMetadata {
358                    id: acp::SessionId(id),
359                    title: summary.into(),
360                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
361                });
362            }
363
364            Ok(threads)
365        })
366    }
367
368    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
369        let connection = self.connection.clone();
370
371        self.executor.spawn(async move {
372            let connection = connection.lock();
373            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
374                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
375            "})?;
376
377            let rows = select(id.0)?;
378            if let Some((data_type, data)) = rows.into_iter().next() {
379                let json_data = match data_type {
380                    DataType::Zstd => {
381                        let decompressed = zstd::decode_all(&data[..])?;
382                        String::from_utf8(decompressed)?
383                    }
384                    DataType::Json => String::from_utf8(data)?,
385                };
386                let thread = DbThread::from_json(json_data.as_bytes())?;
387                Ok(Some(thread))
388            } else {
389                Ok(None)
390            }
391        })
392    }
393
394    pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
395        let connection = self.connection.clone();
396
397        self.executor
398            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
399    }
400
401    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
402        let connection = self.connection.clone();
403
404        self.executor.spawn(async move {
405            let connection = connection.lock();
406
407            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
408                DELETE FROM threads WHERE id = ?
409            "})?;
410
411            delete(id.0)?;
412
413            Ok(())
414        })
415    }
416}
417
#[cfg(test)]
mod tests {

    use super::*;
    use agent::MessageSegment;
    use agent::context::LoadedContext;
    use client::Client;
    use fs::FakeFs;
    use gpui::AppContext;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use language_model::Role;
    use project::Project;
    use settings::SettingsStore;

    /// Registers the settings, language, agent, agent-settings and
    /// language-model globals used by the tests below.
    fn init_test(cx: &mut TestAppContext) {
        let _ = env_logger::try_init();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);

            let clock = Arc::new(clock::FakeSystemClock::new());
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(clock, http_client, cx);
            agent::init(cx);
            agent_settings::init(cx);
            language_model::init(client, cx);
        });
    }

    /// A thread saved by the legacy agent's store can be listed and loaded by
    /// the new database.
    #[gpui::test]
    async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;

        // Persist a two-message thread through the legacy agent...
        let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
        thread.update(cx, |thread, cx| {
            for (role, text) in [(Role::User, "Hey!"), (Role::Assistant, "How're you doing?")] {
                thread.insert_message(
                    role,
                    vec![MessageSegment::Text(text.into())],
                    LoadedContext::default(),
                    vec![],
                    false,
                    cx,
                );
            }
        });
        thread_store
            .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
            .await
            .unwrap();

        // ...then read it back through the new database.
        let db = cx.update(ThreadsDatabase::connect).await.unwrap();
        let listed = db.list_threads().await.unwrap();
        assert_eq!(listed.len(), 1);

        let loaded = db
            .load_thread(listed[0].id.clone())
            .await
            .unwrap()
            .expect("listed thread should be loadable");
        let rendered: Vec<_> = loaded
            .messages
            .iter()
            .map(|message| message.to_markdown())
            .collect();
        assert_eq!(rendered[0], "## User\n\nHey!\n");
        assert_eq!(rendered[1], "## Assistant\n\nHow're you doing?\n");
    }
}