db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use agent::thread_store;
  3use agent_client_protocol as acp;
  4use agent_settings::{AgentProfileId, CompletionMode};
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use parking_lot::Mutex;
 12use serde::{Deserialize, Serialize};
 13use sqlez::{
 14    bindable::{Bind, Column},
 15    connection::Connection,
 16    statement::Statement,
 17};
 18use std::sync::Arc;
 19use ui::{App, SharedString};
 20
/// A single stored thread message; either `crate::Message::User` or `crate::Message::Agent`.
pub type DbMessage = crate::Message;
/// Detailed-summary state, shared with the legacy `agent` thread format so that
/// `DbThread::upgrade_from_agent_1` can copy it over unchanged.
pub type DbSummary = agent::thread::DetailedSummaryState;
/// Serialized language-model selection, shared with the legacy `agent` thread format so that
/// upgrades can copy it over unchanged.
pub type DbLanguageModel = thread_store::SerializedLanguageModel;
 24
/// Lightweight listing entry for a saved thread.
///
/// `ThreadsDatabase::list_threads` builds these from the `id`, `summary` and `updated_at`
/// columns only, without reading or decoding the thread's `data` blob.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    pub id: acp::SessionId,
    // Also accept `summary` as the key when deserializing; the SQL column still uses that name.
    #[serde(alias = "summary")]
    pub title: SharedString,
    pub updated_at: DateTime<Utc>,
}
 32
/// A complete agent thread as persisted in the `threads` table.
///
/// It is written as JSON with an extra top-level `"version"` field equal to
/// `DbThread::VERSION`, then zstd-compressed. Fields marked `#[serde(default)]` take their
/// `Default` value when absent from the stored JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; also written to the table's `summary` column for cheap listing.
    pub title: SharedString,
    pub messages: Vec<DbMessage>,
    /// Last-modified time; also written (RFC 3339) to the `updated_at` column.
    pub updated_at: DateTime<Utc>,
    /// Detailed-summary state (copied from `detailed_summary_state` when upgrading).
    #[serde(default)]
    pub summary: DbSummary,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<language_model::TokenUsage>,
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
 53
 54impl DbThread {
 55    pub const VERSION: &'static str = "0.3.0";
 56
 57    pub fn from_json(json: &[u8]) -> Result<Self> {
 58        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 59        match saved_thread_json.get("version") {
 60            Some(serde_json::Value::String(version)) => match version.as_str() {
 61                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
 62                _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
 63            },
 64            _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
 65        }
 66    }
 67
 68    fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
 69        let mut messages = Vec::new();
 70        for msg in thread.messages {
 71            let message = match msg.role {
 72                language_model::Role::User => {
 73                    let mut content = Vec::new();
 74
 75                    // Convert segments to content
 76                    for segment in msg.segments {
 77                        match segment {
 78                            thread_store::SerializedMessageSegment::Text { text } => {
 79                                content.push(UserMessageContent::Text(text));
 80                            }
 81                            thread_store::SerializedMessageSegment::Thinking { text, .. } => {
 82                                // User messages don't have thinking segments, but handle gracefully
 83                                content.push(UserMessageContent::Text(text));
 84                            }
 85                            thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
 86                                // User messages don't have redacted thinking, skip.
 87                            }
 88                        }
 89                    }
 90
 91                    // If no content was added, add context as text if available
 92                    if content.is_empty() && !msg.context.is_empty() {
 93                        content.push(UserMessageContent::Text(msg.context));
 94                    }
 95
 96                    crate::Message::User(UserMessage {
 97                        // MessageId from old format can't be meaningfully converted, so generate a new one
 98                        id: acp_thread::UserMessageId::new(),
 99                        content,
100                    })
101                }
102                language_model::Role::Assistant => {
103                    let mut content = Vec::new();
104
105                    // Convert segments to content
106                    for segment in msg.segments {
107                        match segment {
108                            thread_store::SerializedMessageSegment::Text { text } => {
109                                content.push(AgentMessageContent::Text(text));
110                            }
111                            thread_store::SerializedMessageSegment::Thinking {
112                                text,
113                                signature,
114                            } => {
115                                content.push(AgentMessageContent::Thinking { text, signature });
116                            }
117                            thread_store::SerializedMessageSegment::RedactedThinking { data } => {
118                                content.push(AgentMessageContent::RedactedThinking(data));
119                            }
120                        }
121                    }
122
123                    // Convert tool uses
124                    let mut tool_names_by_id = HashMap::default();
125                    for tool_use in msg.tool_uses {
126                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
127                        content.push(AgentMessageContent::ToolUse(
128                            language_model::LanguageModelToolUse {
129                                id: tool_use.id,
130                                name: tool_use.name.into(),
131                                raw_input: serde_json::to_string(&tool_use.input)
132                                    .unwrap_or_default(),
133                                input: tool_use.input,
134                                is_input_complete: true,
135                            },
136                        ));
137                    }
138
139                    // Convert tool results
140                    let mut tool_results = IndexMap::default();
141                    for tool_result in msg.tool_results {
142                        let name = tool_names_by_id
143                            .remove(&tool_result.tool_use_id)
144                            .unwrap_or_else(|| SharedString::from("unknown"));
145                        tool_results.insert(
146                            tool_result.tool_use_id.clone(),
147                            language_model::LanguageModelToolResult {
148                                tool_use_id: tool_result.tool_use_id,
149                                tool_name: name.into(),
150                                is_error: tool_result.is_error,
151                                content: tool_result.content,
152                                output: tool_result.output,
153                            },
154                        );
155                    }
156
157                    crate::Message::Agent(AgentMessage {
158                        content,
159                        tool_results,
160                    })
161                }
162                language_model::Role::System => {
163                    // Skip system messages as they're not supported in the new format
164                    continue;
165                }
166            };
167
168            messages.push(message);
169        }
170
171        Ok(Self {
172            title: thread.summary,
173            messages,
174            updated_at: thread.updated_at,
175            summary: thread.detailed_summary_state,
176            initial_project_snapshot: thread.initial_project_snapshot,
177            cumulative_token_usage: thread.cumulative_token_usage,
178            request_token_usage: thread.request_token_usage,
179            model: thread.model,
180            completion_mode: thread.completion_mode,
181            profile: thread.profile,
182        })
183    }
184}
185
/// True when the `ZED_STATELESS` environment variable is set to a non-empty value.
///
/// An unset variable, an empty value, or a non-Unicode value all count as false. When true,
/// `ThreadsDatabase::new` uses an in-memory database instead of a file on disk.
pub static ZED_STATELESS: std::sync::LazyLock<bool> = std::sync::LazyLock::new(|| {
    matches!(std::env::var("ZED_STATELESS"), Ok(value) if !value.is_empty())
});
188
189#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
190pub enum DataType {
191    #[serde(rename = "json")]
192    Json,
193    #[serde(rename = "zstd")]
194    Zstd,
195}
196
197impl Bind for DataType {
198    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
199        let value = match self {
200            DataType::Json => "json",
201            DataType::Zstd => "zstd",
202        };
203        value.bind(statement, start_index)
204    }
205}
206
207impl Column for DataType {
208    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
209        let (value, next_index) = String::column(statement, start_index)?;
210        let data_type = match value.as_str() {
211            "json" => DataType::Json,
212            "zstd" => DataType::Zstd,
213            _ => anyhow::bail!("Unknown data type: {}", value),
214        };
215        Ok((data_type, next_index))
216    }
217}
218
/// Persistent store for agent threads, backed by a single sqlez `Connection`.
pub(crate) struct ThreadsDatabase {
    // Executor on which every query task is spawned.
    executor: BackgroundExecutor,
    // Shared connection; each operation locks it for the duration of its statements.
    connection: Arc<Mutex<Connection>>,
}
223
/// App-global slot holding the shared task that opens the [`ThreadsDatabase`].
///
/// Every caller of `ThreadsDatabase::connect` gets a clone of this task, so the database is
/// only opened once per app.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
227
228impl ThreadsDatabase {
229    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
230        if cx.has_global::<GlobalThreadsDatabase>() {
231            return cx.global::<GlobalThreadsDatabase>().0.clone();
232        }
233        let executor = cx.background_executor().clone();
234        let task = executor
235            .spawn({
236                let executor = executor.clone();
237                async move {
238                    match ThreadsDatabase::new(executor) {
239                        Ok(db) => Ok(Arc::new(db)),
240                        Err(err) => Err(Arc::new(err)),
241                    }
242                }
243            })
244            .shared();
245
246        cx.set_global(GlobalThreadsDatabase(task.clone()));
247        task
248    }
249
250    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
251        let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
252            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
253        } else {
254            let threads_dir = paths::data_dir().join("threads");
255            std::fs::create_dir_all(&threads_dir)?;
256            let sqlite_path = threads_dir.join("threads.db");
257            Connection::open_file(&sqlite_path.to_string_lossy())
258        };
259
260        connection.exec(indoc! {"
261            CREATE TABLE IF NOT EXISTS threads (
262                id TEXT PRIMARY KEY,
263                summary TEXT NOT NULL,
264                updated_at TEXT NOT NULL,
265                data_type TEXT NOT NULL,
266                data BLOB NOT NULL
267            )
268        "})?()
269        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
270
271        let db = Self {
272            executor: executor.clone(),
273            connection: Arc::new(Mutex::new(connection)),
274        };
275
276        Ok(db)
277    }
278
279    fn save_thread_sync(
280        connection: &Arc<Mutex<Connection>>,
281        id: acp::SessionId,
282        thread: DbThread,
283    ) -> Result<()> {
284        const COMPRESSION_LEVEL: i32 = 3;
285
286        #[derive(Serialize)]
287        struct SerializedThread {
288            #[serde(flatten)]
289            thread: DbThread,
290            version: &'static str,
291        }
292
293        let title = thread.title.to_string();
294        let updated_at = thread.updated_at.to_rfc3339();
295        let json_data = serde_json::to_string(&SerializedThread {
296            thread,
297            version: DbThread::VERSION,
298        })?;
299
300        let connection = connection.lock();
301
302        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
303        let data_type = DataType::Zstd;
304        let data = compressed;
305
306        let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
307            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
308        "})?;
309
310        insert((id.0.clone(), title, updated_at, data_type, data))?;
311
312        Ok(())
313    }
314
315    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
316        let connection = self.connection.clone();
317
318        self.executor.spawn(async move {
319            let connection = connection.lock();
320
321            let mut select =
322                connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
323                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
324            "})?;
325
326            let rows = select(())?;
327            let mut threads = Vec::new();
328
329            for (id, summary, updated_at) in rows {
330                threads.push(DbThreadMetadata {
331                    id: acp::SessionId(id),
332                    title: summary.into(),
333                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
334                });
335            }
336
337            Ok(threads)
338        })
339    }
340
341    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
342        let connection = self.connection.clone();
343
344        self.executor.spawn(async move {
345            let connection = connection.lock();
346            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
347                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
348            "})?;
349
350            let rows = select(id.0)?;
351            if let Some((data_type, data)) = rows.into_iter().next() {
352                let json_data = match data_type {
353                    DataType::Zstd => {
354                        let decompressed = zstd::decode_all(&data[..])?;
355                        String::from_utf8(decompressed)?
356                    }
357                    DataType::Json => String::from_utf8(data)?,
358                };
359                let thread = DbThread::from_json(json_data.as_bytes())?;
360                Ok(Some(thread))
361            } else {
362                Ok(None)
363            }
364        })
365    }
366
367    pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
368        let connection = self.connection.clone();
369
370        self.executor
371            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
372    }
373
374    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
375        let connection = self.connection.clone();
376
377        self.executor.spawn(async move {
378            let connection = connection.lock();
379
380            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
381                DELETE FROM threads WHERE id = ?
382            "})?;
383
384            delete(id.0)?;
385
386            Ok(())
387        })
388    }
389}
390
#[cfg(test)]
mod tests {

    use super::*;
    use agent::MessageSegment;
    use agent::context::LoadedContext;
    use client::Client;
    use fs::FakeFs;
    use gpui::AppContext;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use language_model::Role;
    use project::Project;
    use settings::SettingsStore;

    /// Installs the test globals used by this module: a test settings store, project and
    /// language settings, a fake-HTTP client, and the `agent`, `agent_settings` and
    /// `language_model` subsystems.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);

            let http_client = FakeHttpClient::with_404_response();
            let clock = Arc::new(clock::FakeSystemClock::new());
            let client = Client::new(clock, http_client, cx);
            agent::init(cx);
            agent_settings::init(cx);
            language_model::init(client.clone(), cx);
        });
    }

    /// A thread saved through the legacy `agent::ThreadStore` should be listed and loaded by
    /// `ThreadsDatabase`, with its user and assistant messages converted to the new format.
    // NOTE(review): this presumably relies on the legacy store and `ThreadsDatabase::new`
    // both opening the same named in-memory database under test — confirm.
    #[gpui::test]
    async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;

        // Save a thread using the old agent.
        let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
        thread.update(cx, |thread, cx| {
            thread.insert_message(
                Role::User,
                vec![MessageSegment::Text("Hey!".into())],
                LoadedContext::default(),
                vec![],
                false,
                cx,
            );
            thread.insert_message(
                Role::Assistant,
                vec![MessageSegment::Text("How're you doing?".into())],
                LoadedContext::default(),
                vec![],
                false,
                cx,
            )
        });
        thread_store
            .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
            .await
            .unwrap();

        // Open that same thread using the new agent.
        let db = cx.update(ThreadsDatabase::connect).await.unwrap();
        let threads = db.list_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        let thread = db
            .load_thread(threads[0].id.clone())
            .await
            .unwrap()
            .unwrap();
        // Message order and text must survive the upgrade.
        assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n");
        assert_eq!(
            thread.messages[1].to_markdown(),
            "## Assistant\n\nHow're you doing?\n"
        );
    }
}
470}