db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent::{thread::DetailedSummaryState, thread_store};
  4use agent_client_protocol as acp;
  5use agent_settings::{AgentProfileId, CompletionMode};
  6use anyhow::{Result, anyhow};
  7use chrono::{DateTime, Utc};
  8use collections::{HashMap, IndexMap};
  9use futures::{FutureExt, future::Shared};
 10use gpui::{BackgroundExecutor, Global, Task};
 11use indoc::indoc;
 12use parking_lot::Mutex;
 13use serde::{Deserialize, Serialize};
 14use sqlez::{
 15    bindable::{Bind, Column},
 16    connection::Connection,
 17    statement::Statement,
 18};
 19use std::sync::Arc;
 20use ui::{App, SharedString};
 21
/// Message type stored in [`DbThread::messages`]; an alias for this crate's `Message`.
pub type DbMessage = crate::Message;
/// Alias for the original agent's `DetailedSummaryState`.
pub type DbSummary = DetailedSummaryState;
/// Serialized language-model selection stored in [`DbThread::model`]; reuses the
/// original agent's `SerializedLanguageModel`.
pub type DbLanguageModel = thread_store::SerializedLanguageModel;
 25
/// Summary row describing one stored thread, as returned by
/// `ThreadsDatabase::list_threads` (the thread body itself is not included).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session id the thread is stored under (the `id` column of the `threads` table).
    pub id: acp::SessionId,
    /// Thread title. Also deserializes from a field named `summary`.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Last-update timestamp, in UTC.
    pub updated_at: DateTime<Utc>,
}
 33
/// Full persisted form of a thread.
///
/// Every field marked `#[serde(default)]` may be absent from the stored JSON, in
/// which case it deserializes to its `Default` value.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; written to the `summary` column when the thread is saved.
    pub title: SharedString,
    /// Messages of the thread, in order.
    pub messages: Vec<DbMessage>,
    /// Last-update timestamp, in UTC; written to the `updated_at` column when saved.
    pub updated_at: DateTime<Utc>,
    /// Detailed summary text, if one is stored.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Project snapshot carried over from the original agent's thread format, if any.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
    /// Token usage accumulated over the whole thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage keyed by user message id.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Language model recorded for the thread, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Completion mode recorded for the thread, if any.
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    /// Agent profile recorded for the thread, if any.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
 54
 55impl DbThread {
    /// Format version written into the `version` field of every saved thread.
    /// `from_json` deserializes JSON carrying exactly this version directly and
    /// routes everything else through `upgrade_from_agent_1`.
    pub const VERSION: &'static str = "0.3.0";
 57
 58    pub fn from_json(json: &[u8]) -> Result<Self> {
 59        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 60        match saved_thread_json.get("version") {
 61            Some(serde_json::Value::String(version)) => match version.as_str() {
 62                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
 63                _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
 64            },
 65            _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
 66        }
 67    }
 68
 69    fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
 70        let mut messages = Vec::new();
 71        let mut request_token_usage = HashMap::default();
 72
 73        let mut last_user_message_id = None;
 74        for (ix, msg) in thread.messages.into_iter().enumerate() {
 75            let message = match msg.role {
 76                language_model::Role::User => {
 77                    let mut content = Vec::new();
 78
 79                    // Convert segments to content
 80                    for segment in msg.segments {
 81                        match segment {
 82                            thread_store::SerializedMessageSegment::Text { text } => {
 83                                content.push(UserMessageContent::Text(text));
 84                            }
 85                            thread_store::SerializedMessageSegment::Thinking { text, .. } => {
 86                                // User messages don't have thinking segments, but handle gracefully
 87                                content.push(UserMessageContent::Text(text));
 88                            }
 89                            thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
 90                                // User messages don't have redacted thinking, skip.
 91                            }
 92                        }
 93                    }
 94
 95                    // If no content was added, add context as text if available
 96                    if content.is_empty() && !msg.context.is_empty() {
 97                        content.push(UserMessageContent::Text(msg.context));
 98                    }
 99
100                    let id = UserMessageId::new();
101                    last_user_message_id = Some(id.clone());
102
103                    crate::Message::User(UserMessage {
104                        // MessageId from old format can't be meaningfully converted, so generate a new one
105                        id,
106                        content,
107                    })
108                }
109                language_model::Role::Assistant => {
110                    let mut content = Vec::new();
111
112                    // Convert segments to content
113                    for segment in msg.segments {
114                        match segment {
115                            thread_store::SerializedMessageSegment::Text { text } => {
116                                content.push(AgentMessageContent::Text(text));
117                            }
118                            thread_store::SerializedMessageSegment::Thinking {
119                                text,
120                                signature,
121                            } => {
122                                content.push(AgentMessageContent::Thinking { text, signature });
123                            }
124                            thread_store::SerializedMessageSegment::RedactedThinking { data } => {
125                                content.push(AgentMessageContent::RedactedThinking(data));
126                            }
127                        }
128                    }
129
130                    // Convert tool uses
131                    let mut tool_names_by_id = HashMap::default();
132                    for tool_use in msg.tool_uses {
133                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
134                        content.push(AgentMessageContent::ToolUse(
135                            language_model::LanguageModelToolUse {
136                                id: tool_use.id,
137                                name: tool_use.name.into(),
138                                raw_input: serde_json::to_string(&tool_use.input)
139                                    .unwrap_or_default(),
140                                input: tool_use.input,
141                                is_input_complete: true,
142                            },
143                        ));
144                    }
145
146                    // Convert tool results
147                    let mut tool_results = IndexMap::default();
148                    for tool_result in msg.tool_results {
149                        let name = tool_names_by_id
150                            .remove(&tool_result.tool_use_id)
151                            .unwrap_or_else(|| SharedString::from("unknown"));
152                        tool_results.insert(
153                            tool_result.tool_use_id.clone(),
154                            language_model::LanguageModelToolResult {
155                                tool_use_id: tool_result.tool_use_id,
156                                tool_name: name.into(),
157                                is_error: tool_result.is_error,
158                                content: tool_result.content,
159                                output: tool_result.output,
160                            },
161                        );
162                    }
163
164                    if let Some(last_user_message_id) = &last_user_message_id
165                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
166                    {
167                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
168                    }
169
170                    crate::Message::Agent(AgentMessage {
171                        content,
172                        tool_results,
173                    })
174                }
175                language_model::Role::System => {
176                    // Skip system messages as they're not supported in the new format
177                    continue;
178                }
179            };
180
181            messages.push(message);
182        }
183
184        Ok(Self {
185            title: thread.summary,
186            messages,
187            updated_at: thread.updated_at,
188            detailed_summary: match thread.detailed_summary_state {
189                DetailedSummaryState::NotGenerated | DetailedSummaryState::Generating { .. } => {
190                    None
191                }
192                DetailedSummaryState::Generated { text, .. } => Some(text),
193            },
194            initial_project_snapshot: thread.initial_project_snapshot,
195            cumulative_token_usage: thread.cumulative_token_usage,
196            request_token_usage,
197            model: thread.model,
198            completion_mode: thread.completion_mode,
199            profile: thread.profile,
200        })
201    }
202}
203
/// `true` when the `ZED_STATELESS` environment variable is set to a non-empty
/// value. The variable is read once, on first access.
pub static ZED_STATELESS: std::sync::LazyLock<bool> = std::sync::LazyLock::new(|| {
    match std::env::var("ZED_STATELESS") {
        Ok(value) => !value.is_empty(),
        Err(_) => false,
    }
});
206
207#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
208pub enum DataType {
209    #[serde(rename = "json")]
210    Json,
211    #[serde(rename = "zstd")]
212    Zstd,
213}
214
215impl Bind for DataType {
216    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
217        let value = match self {
218            DataType::Json => "json",
219            DataType::Zstd => "zstd",
220        };
221        value.bind(statement, start_index)
222    }
223}
224
225impl Column for DataType {
226    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
227        let (value, next_index) = String::column(statement, start_index)?;
228        let data_type = match value.as_str() {
229            "json" => DataType::Json,
230            "zstd" => DataType::Zstd,
231            _ => anyhow::bail!("Unknown data type: {}", value),
232        };
233        Ok((data_type, next_index))
234    }
235}
236
/// SQLite-backed store of threads, held in a single `threads` table.
///
/// Each public operation clones the shared connection handle and runs on
/// `executor`, locking the connection while it works.
pub(crate) struct ThreadsDatabase {
    /// Executor on which the database operations are spawned.
    executor: BackgroundExecutor,
    /// The one connection used by all operations, guarded by a mutex.
    connection: Arc<Mutex<Connection>>,
}
241
/// App-global holder for the shared task that opens the database, so that
/// `ThreadsDatabase::connect` can hand the same task to every caller.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

// Marker impl that allows the wrapper to be stored as a gpui global.
impl Global for GlobalThreadsDatabase {}
245
246impl ThreadsDatabase {
247    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
248        if cx.has_global::<GlobalThreadsDatabase>() {
249            return cx.global::<GlobalThreadsDatabase>().0.clone();
250        }
251        let executor = cx.background_executor().clone();
252        let task = executor
253            .spawn({
254                let executor = executor.clone();
255                async move {
256                    match ThreadsDatabase::new(executor) {
257                        Ok(db) => Ok(Arc::new(db)),
258                        Err(err) => Err(Arc::new(err)),
259                    }
260                }
261            })
262            .shared();
263
264        cx.set_global(GlobalThreadsDatabase(task.clone()));
265        task
266    }
267
268    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
269        let connection = if *ZED_STATELESS {
270            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
271        } else if cfg!(any(feature = "test-support", test)) {
272            // rust stores the name of the test on the current thread.
273            // We use this to automatically create a database that will
274            // be shared within the test (for the test_retrieve_old_thread)
275            // but not with concurrent tests.
276            let thread = std::thread::current();
277            let test_name = thread.name();
278            Connection::open_memory(Some(&format!(
279                "THREAD_FALLBACK_{}",
280                test_name.unwrap_or_default()
281            )))
282        } else {
283            let threads_dir = paths::data_dir().join("threads");
284            std::fs::create_dir_all(&threads_dir)?;
285            let sqlite_path = threads_dir.join("threads.db");
286            Connection::open_file(&sqlite_path.to_string_lossy())
287        };
288
289        connection.exec(indoc! {"
290            CREATE TABLE IF NOT EXISTS threads (
291                id TEXT PRIMARY KEY,
292                summary TEXT NOT NULL,
293                updated_at TEXT NOT NULL,
294                data_type TEXT NOT NULL,
295                data BLOB NOT NULL
296            )
297        "})?()
298        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
299
300        let db = Self {
301            executor,
302            connection: Arc::new(Mutex::new(connection)),
303        };
304
305        Ok(db)
306    }
307
308    fn save_thread_sync(
309        connection: &Arc<Mutex<Connection>>,
310        id: acp::SessionId,
311        thread: DbThread,
312    ) -> Result<()> {
313        const COMPRESSION_LEVEL: i32 = 3;
314
315        #[derive(Serialize)]
316        struct SerializedThread {
317            #[serde(flatten)]
318            thread: DbThread,
319            version: &'static str,
320        }
321
322        let title = thread.title.to_string();
323        let updated_at = thread.updated_at.to_rfc3339();
324        let json_data = serde_json::to_string(&SerializedThread {
325            thread,
326            version: DbThread::VERSION,
327        })?;
328
329        let connection = connection.lock();
330
331        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
332        let data_type = DataType::Zstd;
333        let data = compressed;
334
335        let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
336            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
337        "})?;
338
339        insert((id.0, title, updated_at, data_type, data))?;
340
341        Ok(())
342    }
343
344    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
345        let connection = self.connection.clone();
346
347        self.executor.spawn(async move {
348            let connection = connection.lock();
349
350            let mut select =
351                connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
352                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
353            "})?;
354
355            let rows = select(())?;
356            let mut threads = Vec::new();
357
358            for (id, summary, updated_at) in rows {
359                threads.push(DbThreadMetadata {
360                    id: acp::SessionId(id),
361                    title: summary.into(),
362                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
363                });
364            }
365
366            Ok(threads)
367        })
368    }
369
370    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
371        let connection = self.connection.clone();
372
373        self.executor.spawn(async move {
374            let connection = connection.lock();
375            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
376                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
377            "})?;
378
379            let rows = select(id.0)?;
380            if let Some((data_type, data)) = rows.into_iter().next() {
381                let json_data = match data_type {
382                    DataType::Zstd => {
383                        let decompressed = zstd::decode_all(&data[..])?;
384                        String::from_utf8(decompressed)?
385                    }
386                    DataType::Json => String::from_utf8(data)?,
387                };
388                let thread = DbThread::from_json(json_data.as_bytes())?;
389                Ok(Some(thread))
390            } else {
391                Ok(None)
392            }
393        })
394    }
395
396    pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
397        let connection = self.connection.clone();
398
399        self.executor
400            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
401    }
402
403    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
404        let connection = self.connection.clone();
405
406        self.executor.spawn(async move {
407            let connection = connection.lock();
408
409            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
410                DELETE FROM threads WHERE id = ?
411            "})?;
412
413            delete(id.0)?;
414
415            Ok(())
416        })
417    }
418}
419
#[cfg(test)]
mod tests {

    use super::*;
    use agent::MessageSegment;
    use agent::context::LoadedContext;
    use client::Client;
    use fs::FakeFs;
    use gpui::AppContext;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use language_model::Role;
    use project::Project;
    use settings::SettingsStore;

    /// Installs the globals these tests rely on: a test `SettingsStore`, project
    /// settings, and the `language`, `agent`, `agent_settings` and
    /// `language_model` subsystems. The client is built from a fake system clock
    /// and an HTTP client created with `FakeHttpClient::with_404_response`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);

            let http_client = FakeHttpClient::with_404_response();
            let clock = Arc::new(clock::FakeSystemClock::new());
            let client = Client::new(clock, http_client, cx);
            agent::init(cx);
            agent_settings::init(cx);
            language_model::init(client, cx);
        });
    }

    /// Saves a two-message thread through the original agent's `ThreadStore`,
    /// then checks that `ThreadsDatabase` lists exactly one thread and loads it
    /// with both messages rendering to the expected markdown.
    ///
    /// NOTE(review): this presumably depends on both stores resolving to the same
    /// per-test in-memory database (see the test branch of `ThreadsDatabase::new`);
    /// confirm against `agent::ThreadStore`.
    #[gpui::test]
    async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;

        // Save a thread using the old agent.
        let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
        thread.update(cx, |thread, cx| {
            thread.insert_message(
                Role::User,
                vec![MessageSegment::Text("Hey!".into())],
                LoadedContext::default(),
                vec![],
                false,
                cx,
            );
            thread.insert_message(
                Role::Assistant,
                vec![MessageSegment::Text("How're you doing?".into())],
                LoadedContext::default(),
                vec![],
                false,
                cx,
            )
        });
        thread_store
            .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
            .await
            .unwrap();

        // Open that same thread using the new agent.
        let db = cx.update(ThreadsDatabase::connect).await.unwrap();
        let threads = db.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        let thread = db
            .load_thread(threads[0].id.clone())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n");
        assert_eq!(
            thread.messages[1].to_markdown(),
            "## Assistant\n\nHow're you doing?\n"
        );
    }
}