db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use agent::thread_store;
  3use agent_client_protocol as acp;
  4use agent_settings::{AgentProfileId, CompletionMode};
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use parking_lot::Mutex;
 12use serde::{Deserialize, Serialize};
 13use sqlez::{
 14    bindable::{Bind, Column},
 15    connection::Connection,
 16    statement::Statement,
 17};
 18use std::sync::Arc;
 19use ui::{App, SharedString};
 20
/// A single message stored in a [`DbThread`].
pub type DbMessage = crate::Message;
/// Detailed-summary state persisted with a thread (reuses the legacy agent's type).
pub type DbSummary = agent::thread::DetailedSummaryState;
/// Language-model selection persisted with a thread (reuses the legacy thread store's type).
pub type DbLanguageModel = thread_store::SerializedLanguageModel;
 24
/// Lightweight listing entry for a saved thread, as produced by
/// `ThreadsDatabase::list_threads` from the `id`, `summary` and `updated_at`
/// columns of the `threads` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session id the thread is stored under (the table's primary key).
    pub id: acp::SessionId,
    /// Thread title. When deserializing, the key `summary` is accepted as an
    /// alias for `title`.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Time of the thread's last update, in UTC.
    pub updated_at: DateTime<Utc>,
}
 32
/// Current on-disk representation of an agent thread (see `DbThread::VERSION`).
///
/// Every field marked `#[serde(default)]` falls back to its `Default` value when
/// it is missing from the stored JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Thread title; `ThreadsDatabase` also writes it to the `summary` column.
    pub title: SharedString,
    /// Messages of the thread.
    pub messages: Vec<DbMessage>,
    /// Last update time (UTC); also written to the `updated_at` column.
    pub updated_at: DateTime<Utc>,
    /// Detailed-summary state carried over from the legacy agent format.
    #[serde(default)]
    pub summary: DbSummary,
    /// Snapshot of the project captured for the thread, if any.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
    /// Token usage accumulated over the whole thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage entries, presumably one per request — TODO confirm with writers.
    #[serde(default)]
    pub request_token_usage: Vec<language_model::TokenUsage>,
    /// Language model selected for the thread, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Completion mode selected for the thread, if any.
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    /// Agent profile selected for the thread, if any.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
 53
 54impl DbThread {
 55    pub const VERSION: &'static str = "0.3.0";
 56
 57    pub fn from_json(json: &[u8]) -> Result<Self> {
 58        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 59        match saved_thread_json.get("version") {
 60            Some(serde_json::Value::String(version)) => match version.as_str() {
 61                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
 62                _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
 63            },
 64            _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
 65        }
 66    }
 67
 68    fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
 69        let mut messages = Vec::new();
 70        for msg in thread.messages {
 71            let message = match msg.role {
 72                language_model::Role::User => {
 73                    let mut content = Vec::new();
 74
 75                    // Convert segments to content
 76                    for segment in msg.segments {
 77                        match segment {
 78                            thread_store::SerializedMessageSegment::Text { text } => {
 79                                content.push(UserMessageContent::Text(text));
 80                            }
 81                            thread_store::SerializedMessageSegment::Thinking { text, .. } => {
 82                                // User messages don't have thinking segments, but handle gracefully
 83                                content.push(UserMessageContent::Text(text));
 84                            }
 85                            thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
 86                                // User messages don't have redacted thinking, skip.
 87                            }
 88                        }
 89                    }
 90
 91                    // If no content was added, add context as text if available
 92                    if content.is_empty() && !msg.context.is_empty() {
 93                        content.push(UserMessageContent::Text(msg.context));
 94                    }
 95
 96                    crate::Message::User(UserMessage {
 97                        // MessageId from old format can't be meaningfully converted, so generate a new one
 98                        id: acp_thread::UserMessageId::new(),
 99                        content,
100                    })
101                }
102                language_model::Role::Assistant => {
103                    let mut content = Vec::new();
104
105                    // Convert segments to content
106                    for segment in msg.segments {
107                        match segment {
108                            thread_store::SerializedMessageSegment::Text { text } => {
109                                content.push(AgentMessageContent::Text(text));
110                            }
111                            thread_store::SerializedMessageSegment::Thinking {
112                                text,
113                                signature,
114                            } => {
115                                content.push(AgentMessageContent::Thinking { text, signature });
116                            }
117                            thread_store::SerializedMessageSegment::RedactedThinking { data } => {
118                                content.push(AgentMessageContent::RedactedThinking(data));
119                            }
120                        }
121                    }
122
123                    // Convert tool uses
124                    let mut tool_names_by_id = HashMap::default();
125                    for tool_use in msg.tool_uses {
126                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
127                        content.push(AgentMessageContent::ToolUse(
128                            language_model::LanguageModelToolUse {
129                                id: tool_use.id,
130                                name: tool_use.name.into(),
131                                raw_input: serde_json::to_string(&tool_use.input)
132                                    .unwrap_or_default(),
133                                input: tool_use.input,
134                                is_input_complete: true,
135                            },
136                        ));
137                    }
138
139                    // Convert tool results
140                    let mut tool_results = IndexMap::default();
141                    for tool_result in msg.tool_results {
142                        let name = tool_names_by_id
143                            .remove(&tool_result.tool_use_id)
144                            .unwrap_or_else(|| SharedString::from("unknown"));
145                        tool_results.insert(
146                            tool_result.tool_use_id.clone(),
147                            language_model::LanguageModelToolResult {
148                                tool_use_id: tool_result.tool_use_id,
149                                tool_name: name.into(),
150                                is_error: tool_result.is_error,
151                                content: tool_result.content,
152                                output: tool_result.output,
153                            },
154                        );
155                    }
156
157                    crate::Message::Agent(AgentMessage {
158                        content,
159                        tool_results,
160                    })
161                }
162                language_model::Role::System => {
163                    // Skip system messages as they're not supported in the new format
164                    continue;
165                }
166            };
167
168            messages.push(message);
169        }
170
171        Ok(Self {
172            title: thread.summary,
173            messages,
174            updated_at: thread.updated_at,
175            summary: thread.detailed_summary_state,
176            initial_project_snapshot: thread.initial_project_snapshot,
177            cumulative_token_usage: thread.cumulative_token_usage,
178            request_token_usage: thread.request_token_usage,
179            model: thread.model,
180            completion_mode: thread.completion_mode,
181            profile: thread.profile,
182        })
183    }
184}
185
/// Whether the `ZED_STATELESS` environment variable is set to a non-empty,
/// valid-Unicode value. It is read once, on first access. When `true`, the
/// threads database is opened in memory rather than on disk.
pub static ZED_STATELESS: std::sync::LazyLock<bool> = std::sync::LazyLock::new(|| {
    std::env::var("ZED_STATELESS").is_ok_and(|value| !value.is_empty())
});
188
189#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
190pub enum DataType {
191    #[serde(rename = "json")]
192    Json,
193    #[serde(rename = "zstd")]
194    Zstd,
195}
196
197impl Bind for DataType {
198    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
199        let value = match self {
200            DataType::Json => "json",
201            DataType::Zstd => "zstd",
202        };
203        value.bind(statement, start_index)
204    }
205}
206
207impl Column for DataType {
208    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
209        let (value, next_index) = String::column(statement, start_index)?;
210        let data_type = match value.as_str() {
211            "json" => DataType::Json,
212            "zstd" => DataType::Zstd,
213            _ => anyhow::bail!("Unknown data type: {}", value),
214        };
215        Ok((data_type, next_index))
216    }
217}
218
/// SQLite-backed store for agent threads.
///
/// Queries are spawned on `executor`. The single `connection` is shared behind
/// a mutex, and each task locks it while it prepares and runs its statements.
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    connection: Arc<Mutex<Connection>>,
}
223
/// gpui global holding the shared task that opens the [`ThreadsDatabase`], so
/// `ThreadsDatabase::connect` opens the database at most once per app.
/// The error is `Arc`-wrapped because a `Shared` future's output must be `Clone`.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

impl Global for GlobalThreadsDatabase {}
227
228impl ThreadsDatabase {
229    fn connection(&self) -> Arc<Mutex<Connection>> {
230        self.connection.clone()
231    }
232
233    const COMPRESSION_LEVEL: i32 = 3;
234}
235
236impl ThreadsDatabase {
237    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
238        if cx.has_global::<GlobalThreadsDatabase>() {
239            return cx.global::<GlobalThreadsDatabase>().0.clone();
240        }
241        let executor = cx.background_executor().clone();
242        let task = executor
243            .spawn({
244                let executor = executor.clone();
245                async move {
246                    match ThreadsDatabase::new(executor) {
247                        Ok(db) => Ok(Arc::new(db)),
248                        Err(err) => Err(Arc::new(err)),
249                    }
250                }
251            })
252            .shared();
253
254        cx.set_global(GlobalThreadsDatabase(task.clone()));
255        task
256    }
257
258    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
259        let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
260            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
261        } else {
262            let threads_dir = paths::data_dir().join("threads");
263            std::fs::create_dir_all(&threads_dir)?;
264            let sqlite_path = threads_dir.join("threads.db");
265            Connection::open_file(&sqlite_path.to_string_lossy())
266        };
267
268        connection.exec(indoc! {"
269            CREATE TABLE IF NOT EXISTS threads (
270                id TEXT PRIMARY KEY,
271                summary TEXT NOT NULL,
272                updated_at TEXT NOT NULL,
273                data_type TEXT NOT NULL,
274                data BLOB NOT NULL
275            )
276        "})?()
277        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
278
279        let db = Self {
280            executor: executor.clone(),
281            connection: Arc::new(Mutex::new(connection)),
282        };
283
284        Ok(db)
285    }
286
287    fn save_thread_sync(
288        connection: &Arc<Mutex<Connection>>,
289        id: acp::SessionId,
290        thread: DbThread,
291    ) -> Result<()> {
292        let json_data = serde_json::to_string(&thread)?;
293        let title = thread.title.to_string();
294        let updated_at = thread.updated_at.to_rfc3339();
295
296        let connection = connection.lock();
297
298        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
299        let data_type = DataType::Zstd;
300        let data = compressed;
301
302        let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
303            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
304        "})?;
305
306        insert((id.0, title, updated_at, data_type, data))?;
307
308        Ok(())
309    }
310
311    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
312        let connection = self.connection.clone();
313
314        self.executor.spawn(async move {
315            let connection = connection.lock();
316            let mut select =
317                connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
318                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
319            "})?;
320
321            let rows = select(())?;
322            let mut threads = Vec::new();
323
324            for (id, summary, updated_at) in rows {
325                threads.push(DbThreadMetadata {
326                    id: acp::SessionId(id),
327                    title: summary.into(),
328                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
329                });
330            }
331
332            Ok(threads)
333        })
334    }
335
336    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
337        let connection = self.connection.clone();
338
339        self.executor.spawn(async move {
340            let connection = connection.lock();
341            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
342                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
343            "})?;
344
345            let rows = select(id.0)?;
346            if let Some((data_type, data)) = rows.into_iter().next() {
347                let json_data = match data_type {
348                    DataType::Zstd => {
349                        let decompressed = zstd::decode_all(&data[..])?;
350                        String::from_utf8(decompressed)?
351                    }
352                    DataType::Json => String::from_utf8(data)?,
353                };
354
355                let thread = DbThread::from_json(json_data.as_bytes())?;
356                Ok(Some(thread))
357            } else {
358                Ok(None)
359            }
360        })
361    }
362
363    pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
364        let connection = self.connection.clone();
365
366        self.executor
367            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
368    }
369
370    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
371        let connection = self.connection.clone();
372
373        self.executor.spawn(async move {
374            let connection = connection.lock();
375
376            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
377                DELETE FROM threads WHERE id = ?
378            "})?;
379
380            delete(id.0)?;
381
382            Ok(())
383        })
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use agent::MessageSegment;
    use agent::context::LoadedContext;
    use client::Client;
    use fs::FakeFs;
    use gpui::{AppContext, TestAppContext};
    use http_client::FakeHttpClient;
    use language_model::Role;
    use project::Project;
    use settings::SettingsStore;

    /// Registers the settings, language, agent and language-model globals that
    /// the legacy thread store and the threads database rely on in tests.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|app| {
            let settings = SettingsStore::test(app);
            app.set_global(settings);
            Project::init_settings(app);
            language::init(app);

            let fake_http = FakeHttpClient::with_404_response();
            let fake_clock = Arc::new(clock::FakeSystemClock::new());
            let client = Client::new(fake_clock, fake_http, app);
            agent::init(app);
            agent_settings::init(app);
            language_model::init(client, app);
        });
    }

    /// A thread saved by the legacy `agent::ThreadStore` must be listed and
    /// loaded through `ThreadsDatabase`, converted to the current message format.
    #[gpui::test]
    async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;

        // Persist a two-message conversation through the legacy thread store.
        {
            let legacy_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
            let legacy_thread = legacy_store.update(cx, |store, cx| store.create_thread(cx));
            legacy_thread.update(cx, |thread, cx| {
                for (role, text) in [(Role::User, "Hey!"), (Role::Assistant, "How're you doing?")] {
                    thread.insert_message(
                        role,
                        vec![MessageSegment::Text(text.into())],
                        LoadedContext::default(),
                        vec![],
                        false,
                        cx,
                    );
                }
            });
            legacy_store
                .update(cx, |store, cx| store.save_thread(&legacy_thread, cx))
                .await
                .unwrap();
        }

        let db = cx.update(ThreadsDatabase::connect).await.unwrap();
        let listed = db.list_threads().await.unwrap();
        assert_eq!(listed.len(), 1);

        let loaded = db
            .load_thread(listed[0].id.clone())
            .await
            .unwrap()
            .expect("the saved thread should be present");
        assert_eq!(loaded.messages[0].to_markdown(), "## User\n\nHey!\n");
        assert_eq!(
            loaded.messages[1].to_markdown(),
            "## Assistant\n\nHow're you doing?\n"
        );
    }
}