db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::{AcpThreadMetadata, AgentServerName};
  3use agent::thread_store;
  4use agent_client_protocol as acp;
  5use agent_settings::{AgentProfileId, CompletionMode};
  6use anyhow::{Result, anyhow};
  7use chrono::{DateTime, Utc};
  8use collections::{HashMap, IndexMap};
  9use futures::{FutureExt, future::Shared};
 10use gpui::{BackgroundExecutor, Global, Task};
 11use indoc::indoc;
 12use parking_lot::Mutex;
 13use serde::{Deserialize, Serialize};
 14use sqlez::{
 15    bindable::{Bind, Column},
 16    connection::Connection,
 17    statement::Statement,
 18};
 19use std::sync::Arc;
 20use ui::{App, SharedString};
 21
 22pub type DbMessage = crate::Message;
 23pub type DbSummary = agent::thread::DetailedSummaryState;
 24pub type DbLanguageModel = thread_store::SerializedLanguageModel;
 25
 26#[derive(Debug, Clone, Serialize, Deserialize)]
 27pub struct DbThreadMetadata {
 28    pub id: acp::SessionId,
 29    #[serde(alias = "summary")]
 30    pub title: SharedString,
 31    pub updated_at: DateTime<Utc>,
 32}
 33
 34impl DbThreadMetadata {
 35    pub fn to_acp(self, agent: AgentServerName) -> AcpThreadMetadata {
 36        AcpThreadMetadata {
 37            agent,
 38            id: self.id,
 39            title: self.title,
 40            updated_at: self.updated_at,
 41        }
 42    }
 43}
 44
 45#[derive(Debug, Serialize, Deserialize)]
 46pub struct DbThread {
 47    pub title: SharedString,
 48    pub messages: Vec<DbMessage>,
 49    pub updated_at: DateTime<Utc>,
 50    #[serde(default)]
 51    pub summary: DbSummary,
 52    #[serde(default)]
 53    pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
 54    #[serde(default)]
 55    pub cumulative_token_usage: language_model::TokenUsage,
 56    #[serde(default)]
 57    pub request_token_usage: Vec<language_model::TokenUsage>,
 58    #[serde(default)]
 59    pub model: Option<DbLanguageModel>,
 60    #[serde(default)]
 61    pub completion_mode: Option<CompletionMode>,
 62    #[serde(default)]
 63    pub profile: Option<AgentProfileId>,
 64}
 65
 66impl DbThread {
 67    pub const VERSION: &'static str = "0.3.0";
 68
 69    pub fn from_json(json: &[u8]) -> Result<Self> {
 70        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 71        match saved_thread_json.get("version") {
 72            Some(serde_json::Value::String(version)) => match version.as_str() {
 73                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
 74                _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
 75            },
 76            _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
 77        }
 78    }
 79
 80    fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
 81        let mut messages = Vec::new();
 82        for msg in thread.messages {
 83            let message = match msg.role {
 84                language_model::Role::User => {
 85                    let mut content = Vec::new();
 86
 87                    // Convert segments to content
 88                    for segment in msg.segments {
 89                        match segment {
 90                            thread_store::SerializedMessageSegment::Text { text } => {
 91                                content.push(UserMessageContent::Text(text));
 92                            }
 93                            thread_store::SerializedMessageSegment::Thinking { text, .. } => {
 94                                // User messages don't have thinking segments, but handle gracefully
 95                                content.push(UserMessageContent::Text(text));
 96                            }
 97                            thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
 98                                // User messages don't have redacted thinking, skip.
 99                            }
100                        }
101                    }
102
103                    // If no content was added, add context as text if available
104                    if content.is_empty() && !msg.context.is_empty() {
105                        content.push(UserMessageContent::Text(msg.context));
106                    }
107
108                    crate::Message::User(UserMessage {
109                        // MessageId from old format can't be meaningfully converted, so generate a new one
110                        id: acp_thread::UserMessageId::new(),
111                        content,
112                    })
113                }
114                language_model::Role::Assistant => {
115                    let mut content = Vec::new();
116
117                    // Convert segments to content
118                    for segment in msg.segments {
119                        match segment {
120                            thread_store::SerializedMessageSegment::Text { text } => {
121                                content.push(AgentMessageContent::Text(text));
122                            }
123                            thread_store::SerializedMessageSegment::Thinking {
124                                text,
125                                signature,
126                            } => {
127                                content.push(AgentMessageContent::Thinking { text, signature });
128                            }
129                            thread_store::SerializedMessageSegment::RedactedThinking { data } => {
130                                content.push(AgentMessageContent::RedactedThinking(data));
131                            }
132                        }
133                    }
134
135                    // Convert tool uses
136                    let mut tool_names_by_id = HashMap::default();
137                    for tool_use in msg.tool_uses {
138                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
139                        content.push(AgentMessageContent::ToolUse(
140                            language_model::LanguageModelToolUse {
141                                id: tool_use.id,
142                                name: tool_use.name.into(),
143                                raw_input: serde_json::to_string(&tool_use.input)
144                                    .unwrap_or_default(),
145                                input: tool_use.input,
146                                is_input_complete: true,
147                            },
148                        ));
149                    }
150
151                    // Convert tool results
152                    let mut tool_results = IndexMap::default();
153                    for tool_result in msg.tool_results {
154                        let name = tool_names_by_id
155                            .remove(&tool_result.tool_use_id)
156                            .unwrap_or_else(|| SharedString::from("unknown"));
157                        tool_results.insert(
158                            tool_result.tool_use_id.clone(),
159                            language_model::LanguageModelToolResult {
160                                tool_use_id: tool_result.tool_use_id,
161                                tool_name: name.into(),
162                                is_error: tool_result.is_error,
163                                content: tool_result.content,
164                                output: tool_result.output,
165                            },
166                        );
167                    }
168
169                    crate::Message::Agent(AgentMessage {
170                        content,
171                        tool_results,
172                    })
173                }
174                language_model::Role::System => {
175                    // Skip system messages as they're not supported in the new format
176                    continue;
177                }
178            };
179
180            messages.push(message);
181        }
182
183        Ok(Self {
184            title: thread.summary,
185            messages,
186            updated_at: thread.updated_at,
187            summary: thread.detailed_summary_state,
188            initial_project_snapshot: thread.initial_project_snapshot,
189            cumulative_token_usage: thread.cumulative_token_usage,
190            request_token_usage: thread.request_token_usage,
191            model: thread.model,
192            completion_mode: thread.completion_mode,
193            profile: thread.profile,
194        })
195    }
196}
197
/// `true` when the `ZED_STATELESS` environment variable is set to a non-empty value.
///
/// An unset or non-Unicode variable counts as `false`. The variable is read once, on
/// first access. When `true`, `ThreadsDatabase::new` uses an in-memory database
/// instead of a file on disk.
pub static ZED_STATELESS: std::sync::LazyLock<bool> =
    std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty()));
200
201#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
202pub enum DataType {
203    #[serde(rename = "json")]
204    Json,
205    #[serde(rename = "zstd")]
206    Zstd,
207}
208
209impl Bind for DataType {
210    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
211        let value = match self {
212            DataType::Json => "json",
213            DataType::Zstd => "zstd",
214        };
215        value.bind(statement, start_index)
216    }
217}
218
219impl Column for DataType {
220    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
221        let (value, next_index) = String::column(statement, start_index)?;
222        let data_type = match value.as_str() {
223            "json" => DataType::Json,
224            "zstd" => DataType::Zstd,
225            _ => anyhow::bail!("Unknown data type: {}", value),
226        };
227        Ok((data_type, next_index))
228    }
229}
230
231pub(crate) struct ThreadsDatabase {
232    executor: BackgroundExecutor,
233    connection: Arc<Mutex<Connection>>,
234}
235
236struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
237
238impl Global for GlobalThreadsDatabase {}
239
240impl ThreadsDatabase {
241    fn connection(&self) -> Arc<Mutex<Connection>> {
242        self.connection.clone()
243    }
244
245    const COMPRESSION_LEVEL: i32 = 3;
246}
247
248impl ThreadsDatabase {
249    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
250        if cx.has_global::<GlobalThreadsDatabase>() {
251            return cx.global::<GlobalThreadsDatabase>().0.clone();
252        }
253        let executor = cx.background_executor().clone();
254        let task = executor
255            .spawn({
256                let executor = executor.clone();
257                async move {
258                    match ThreadsDatabase::new(executor) {
259                        Ok(db) => Ok(Arc::new(db)),
260                        Err(err) => Err(Arc::new(err)),
261                    }
262                }
263            })
264            .shared();
265
266        cx.set_global(GlobalThreadsDatabase(task.clone()));
267        task
268    }
269
270    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
271        let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
272            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
273        } else {
274            let threads_dir = paths::data_dir().join("threads");
275            std::fs::create_dir_all(&threads_dir)?;
276            let sqlite_path = threads_dir.join("threads.db");
277            Connection::open_file(&sqlite_path.to_string_lossy())
278        };
279
280        connection.exec(indoc! {"
281            CREATE TABLE IF NOT EXISTS threads (
282                id TEXT PRIMARY KEY,
283                summary TEXT NOT NULL,
284                updated_at TEXT NOT NULL,
285                data_type TEXT NOT NULL,
286                data BLOB NOT NULL
287            )
288        "})?()
289        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
290
291        let db = Self {
292            executor: executor.clone(),
293            connection: Arc::new(Mutex::new(connection)),
294        };
295
296        Ok(db)
297    }
298
299    fn save_thread_sync(
300        connection: &Arc<Mutex<Connection>>,
301        id: acp::SessionId,
302        thread: DbThread,
303    ) -> Result<DbThreadMetadata> {
304        let json_data = serde_json::to_string(&thread)?;
305        let title = thread.title.to_string();
306        let updated_at = thread.updated_at.to_rfc3339();
307
308        let connection = connection.lock();
309
310        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
311        let data_type = DataType::Zstd;
312        let data = compressed;
313
314        let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
315            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
316        "})?;
317
318        insert((id.0.clone(), title, updated_at, data_type, data))?;
319
320        Ok(DbThreadMetadata {
321            id,
322            title: thread.title,
323            updated_at: thread.updated_at,
324        })
325    }
326
327    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
328        let connection = self.connection.clone();
329
330        self.executor.spawn(async move {
331            let connection = connection.lock();
332            let mut select =
333                connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
334                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
335            "})?;
336
337            let rows = select(())?;
338            let mut threads = Vec::new();
339
340            for (id, summary, updated_at) in rows {
341                threads.push(DbThreadMetadata {
342                    id: acp::SessionId(id),
343                    title: summary.into(),
344                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
345                });
346            }
347
348            Ok(threads)
349        })
350    }
351
352    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
353        let connection = self.connection.clone();
354
355        self.executor.spawn(async move {
356            let connection = connection.lock();
357            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
358                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
359            "})?;
360
361            let rows = select(id.0)?;
362            if let Some((data_type, data)) = rows.into_iter().next() {
363                let json_data = match data_type {
364                    DataType::Zstd => {
365                        let decompressed = zstd::decode_all(&data[..])?;
366                        String::from_utf8(decompressed)?
367                    }
368                    DataType::Json => String::from_utf8(data)?,
369                };
370                dbg!(&json_data);
371
372                let thread = dbg!(DbThread::from_json(json_data.as_bytes()))?;
373                Ok(Some(thread))
374            } else {
375                Ok(None)
376            }
377        })
378    }
379
380    pub fn save_thread(
381        &self,
382        id: acp::SessionId,
383        thread: DbThread,
384    ) -> Task<Result<DbThreadMetadata>> {
385        let connection = self.connection.clone();
386
387        self.executor
388            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
389    }
390
391    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
392        let connection = self.connection.clone();
393
394        self.executor.spawn(async move {
395            let connection = connection.lock();
396
397            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
398                DELETE FROM threads WHERE id = ?
399            "})?;
400
401            delete(id.0)?;
402
403            Ok(())
404        })
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use agent::MessageSegment;
    use agent::context::LoadedContext;
    use client::Client;
    use fs::FakeFs;
    use gpui::{AppContext, TestAppContext};
    use http_client::FakeHttpClient;
    use language_model::Role;
    use project::Project;
    use settings::SettingsStore;

    /// Installs the globals these tests rely on: settings, language and agent
    /// initialization, and a language-model setup backed by a fake HTTP client.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            Project::init_settings(cx);
            language::init(cx);

            let fake_http = FakeHttpClient::with_404_response();
            let fake_clock = Arc::new(clock::FakeSystemClock::new());
            let client = Client::new(fake_clock, fake_http, cx);
            agent::init(cx);
            agent_settings::init(cx);
            language_model::init(client, cx);
        });
    }

    /// A thread written by the legacy agent's thread store can be listed and loaded
    /// through `ThreadsDatabase`, with its messages converted to the new format.
    #[gpui::test]
    async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;

        // Persist a two-message thread through the legacy agent's store.
        {
            let legacy_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
            let legacy_thread =
                legacy_store.update(cx, |store, cx| store.create_thread(cx));
            legacy_thread.update(cx, |thread, cx| {
                thread.insert_message(
                    Role::User,
                    vec![MessageSegment::Text("Hey!".into())],
                    LoadedContext::default(),
                    vec![],
                    false,
                    cx,
                );
                thread.insert_message(
                    Role::Assistant,
                    vec![MessageSegment::Text("How're you doing?".into())],
                    LoadedContext::default(),
                    vec![],
                    false,
                    cx,
                );
            });
            let save = legacy_store.update(cx, |store, cx| store.save_thread(&legacy_thread, cx));
            save.await.unwrap();
        }

        // Read it back through the new database; loading upgrades the legacy format.
        let db = cx.update(|cx| ThreadsDatabase::connect(cx)).await.unwrap();
        let listed = db.list_threads().await.unwrap();
        assert_eq!(listed.len(), 1);

        let loaded = db.load_thread(listed[0].id.clone()).await.unwrap();
        let loaded = loaded.unwrap();
        assert_eq!(loaded.messages[0].to_markdown(), "## User\n\nHey!\n");
        assert_eq!(
            loaded.messages[1].to_markdown(),
            "## Assistant\n\nHow're you doing?\n"
        );
    }
}