db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::AgentProfileId;
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use parking_lot::Mutex;
 12use serde::{Deserialize, Serialize};
 13use sqlez::{
 14    bindable::{Bind, Column},
 15    connection::Connection,
 16    statement::Statement,
 17};
 18use std::sync::Arc;
 19use ui::{App, SharedString};
 20use zed_env_vars::ZED_STATELESS;
 21
/// Element type of `DbThread::messages` and `SharedThread::messages`; an alias for
/// `crate::Message`.
pub type DbMessage = crate::Message;
/// Alias for the legacy thread format's `DetailedSummaryState`.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Type of the `model` field on `DbThread` and `SharedThread`; an alias for the legacy
/// thread format's `SerializedLanguageModel`.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 25
/// Summary information about a stored thread: its id, title and last-update time.
///
/// `ThreadsDatabase::list_threads` builds one of these per row from the `id`, `summary`
/// and `updated_at` columns of the `threads` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    pub id: acp::SessionId,
    // When deserializing, the key `summary` is accepted as an alternative to `title`.
    #[serde(alias = "summary")]
    pub title: SharedString,
    pub updated_at: DateTime<Utc>,
}
 33
/// A complete thread as stored in the `data` column of the `threads` table.
///
/// `ThreadsDatabase::save_thread_sync` serializes this struct to JSON (flattened, with an
/// extra `version` field set to `DbThread::VERSION`) and stores it zstd-compressed.
/// `DbThread::from_json` reads it back, upgrading documents in the legacy format.
///
/// Only `title`, `messages` and `updated_at` are required when deserializing; every
/// other field carries `#[serde(default)]` and falls back to its `Default` value when
/// the key is absent.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    pub title: SharedString,
    pub messages: Vec<DbMessage>,
    pub updated_at: DateTime<Utc>,
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    // Token usage keyed by the id of a user message.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
    // Set to `true` by `SharedThread::to_db_thread`; `false` for threads produced by
    // the legacy upgrade path and for documents that lack the key.
    #[serde(default)]
    pub imported: bool,
}
 54
/// The subset of a `DbThread` that is exchanged when a thread is shared: title,
/// messages, last-update time and model, plus a format `version` string.
///
/// `to_bytes` / `from_bytes` convert this struct to and from zstd-compressed JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
    pub title: SharedString,
    pub messages: Vec<DbMessage>,
    pub updated_at: DateTime<Utc>,
    // Optional when deserializing; defaults to `None` if the key is missing.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    // `from_db_thread` fills this with `SharedThread::VERSION`.
    pub version: String,
}
 64
 65impl SharedThread {
 66    pub const VERSION: &'static str = "1.0.0";
 67
 68    pub fn from_db_thread(thread: &DbThread) -> Self {
 69        Self {
 70            title: thread.title.clone(),
 71            messages: thread.messages.clone(),
 72            updated_at: thread.updated_at,
 73            model: thread.model.clone(),
 74            version: Self::VERSION.to_string(),
 75        }
 76    }
 77
 78    pub fn to_db_thread(self) -> DbThread {
 79        DbThread {
 80            title: format!("🔗 {}", self.title).into(),
 81            messages: self.messages,
 82            updated_at: self.updated_at,
 83            detailed_summary: None,
 84            initial_project_snapshot: None,
 85            cumulative_token_usage: Default::default(),
 86            request_token_usage: Default::default(),
 87            model: self.model,
 88            profile: None,
 89            imported: true,
 90        }
 91    }
 92
 93    pub fn to_bytes(&self) -> Result<Vec<u8>> {
 94        const COMPRESSION_LEVEL: i32 = 3;
 95        let json = serde_json::to_vec(self)?;
 96        let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
 97        Ok(compressed)
 98    }
 99
100    pub fn from_bytes(data: &[u8]) -> Result<Self> {
101        let decompressed = zstd::decode_all(data)?;
102        Ok(serde_json::from_slice(&decompressed)?)
103    }
104}
105
106impl DbThread {
107    pub const VERSION: &'static str = "0.3.0";
108
109    pub fn from_json(json: &[u8]) -> Result<Self> {
110        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
111        match saved_thread_json.get("version") {
112            Some(serde_json::Value::String(version)) => match version.as_str() {
113                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
114                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
115                    json,
116                )?),
117            },
118            _ => {
119                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
120            }
121        }
122    }
123
124    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
125        let mut messages = Vec::new();
126        let mut request_token_usage = HashMap::default();
127
128        let mut last_user_message_id = None;
129        for (ix, msg) in thread.messages.into_iter().enumerate() {
130            let message = match msg.role {
131                language_model::Role::User => {
132                    let mut content = Vec::new();
133
134                    // Convert segments to content
135                    for segment in msg.segments {
136                        match segment {
137                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
138                                content.push(UserMessageContent::Text(text));
139                            }
140                            crate::legacy_thread::SerializedMessageSegment::Thinking {
141                                text,
142                                ..
143                            } => {
144                                // User messages don't have thinking segments, but handle gracefully
145                                content.push(UserMessageContent::Text(text));
146                            }
147                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
148                                ..
149                            } => {
150                                // User messages don't have redacted thinking, skip.
151                            }
152                        }
153                    }
154
155                    // If no content was added, add context as text if available
156                    if content.is_empty() && !msg.context.is_empty() {
157                        content.push(UserMessageContent::Text(msg.context));
158                    }
159
160                    let id = UserMessageId::new();
161                    last_user_message_id = Some(id.clone());
162
163                    crate::Message::User(UserMessage {
164                        // MessageId from old format can't be meaningfully converted, so generate a new one
165                        id,
166                        content,
167                    })
168                }
169                language_model::Role::Assistant => {
170                    let mut content = Vec::new();
171
172                    // Convert segments to content
173                    for segment in msg.segments {
174                        match segment {
175                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
176                                content.push(AgentMessageContent::Text(text));
177                            }
178                            crate::legacy_thread::SerializedMessageSegment::Thinking {
179                                text,
180                                signature,
181                            } => {
182                                content.push(AgentMessageContent::Thinking { text, signature });
183                            }
184                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
185                                data,
186                            } => {
187                                content.push(AgentMessageContent::RedactedThinking(data));
188                            }
189                        }
190                    }
191
192                    // Convert tool uses
193                    let mut tool_names_by_id = HashMap::default();
194                    for tool_use in msg.tool_uses {
195                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
196                        content.push(AgentMessageContent::ToolUse(
197                            language_model::LanguageModelToolUse {
198                                id: tool_use.id,
199                                name: tool_use.name.into(),
200                                raw_input: serde_json::to_string(&tool_use.input)
201                                    .unwrap_or_default(),
202                                input: tool_use.input,
203                                is_input_complete: true,
204                                thought_signature: None,
205                            },
206                        ));
207                    }
208
209                    // Convert tool results
210                    let mut tool_results = IndexMap::default();
211                    for tool_result in msg.tool_results {
212                        let name = tool_names_by_id
213                            .remove(&tool_result.tool_use_id)
214                            .unwrap_or_else(|| SharedString::from("unknown"));
215                        tool_results.insert(
216                            tool_result.tool_use_id.clone(),
217                            language_model::LanguageModelToolResult {
218                                tool_use_id: tool_result.tool_use_id,
219                                tool_name: name.into(),
220                                is_error: tool_result.is_error,
221                                content: tool_result.content,
222                                output: tool_result.output,
223                            },
224                        );
225                    }
226
227                    if let Some(last_user_message_id) = &last_user_message_id
228                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
229                    {
230                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
231                    }
232
233                    crate::Message::Agent(AgentMessage {
234                        content,
235                        tool_results,
236                        reasoning_details: None,
237                    })
238                }
239                language_model::Role::System => {
240                    // Skip system messages as they're not supported in the new format
241                    continue;
242                }
243            };
244
245            messages.push(message);
246        }
247
248        Ok(Self {
249            title: thread.summary,
250            messages,
251            updated_at: thread.updated_at,
252            detailed_summary: match thread.detailed_summary_state {
253                crate::legacy_thread::DetailedSummaryState::NotGenerated
254                | crate::legacy_thread::DetailedSummaryState::Generating => None,
255                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
256            },
257            initial_project_snapshot: thread.initial_project_snapshot,
258            cumulative_token_usage: thread.cumulative_token_usage,
259            request_token_usage,
260            model: thread.model,
261            profile: thread.profile,
262            imported: false,
263        })
264    }
265}
266
267#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
268pub enum DataType {
269    #[serde(rename = "json")]
270    Json,
271    #[serde(rename = "zstd")]
272    Zstd,
273}
274
275impl Bind for DataType {
276    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
277        let value = match self {
278            DataType::Json => "json",
279            DataType::Zstd => "zstd",
280        };
281        value.bind(statement, start_index)
282    }
283}
284
285impl Column for DataType {
286    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
287        let (value, next_index) = String::column(statement, start_index)?;
288        let data_type = match value.as_str() {
289            "json" => DataType::Json,
290            "zstd" => DataType::Zstd,
291            _ => anyhow::bail!("Unknown data type: {}", value),
292        };
293        Ok((data_type, next_index))
294    }
295}
296
/// Handle to the `threads` table.
///
/// Each operation clones the `connection` handle, spawns a task on `executor`, and locks
/// the mutex inside that task before talking to SQLite.
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    connection: Arc<Mutex<Connection>>,
}
301
/// App-global wrapper around the shared task that opens the database.
///
/// `ThreadsDatabase::connect` stores the task here on first use and hands out clones of
/// it on later calls.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

// Marker impl so the wrapper can be stored with `set_global` / read with `global`.
impl Global for GlobalThreadsDatabase {}
305
306impl ThreadsDatabase {
307    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
308        if cx.has_global::<GlobalThreadsDatabase>() {
309            return cx.global::<GlobalThreadsDatabase>().0.clone();
310        }
311        let executor = cx.background_executor().clone();
312        let task = executor
313            .spawn({
314                let executor = executor.clone();
315                async move {
316                    match ThreadsDatabase::new(executor) {
317                        Ok(db) => Ok(Arc::new(db)),
318                        Err(err) => Err(Arc::new(err)),
319                    }
320                }
321            })
322            .shared();
323
324        cx.set_global(GlobalThreadsDatabase(task.clone()));
325        task
326    }
327
328    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
329        let connection = if *ZED_STATELESS {
330            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
331        } else if cfg!(any(feature = "test-support", test)) {
332            // rust stores the name of the test on the current thread.
333            // We use this to automatically create a database that will
334            // be shared within the test (for the test_retrieve_old_thread)
335            // but not with concurrent tests.
336            let thread = std::thread::current();
337            let test_name = thread.name();
338            Connection::open_memory(Some(&format!(
339                "THREAD_FALLBACK_{}",
340                test_name.unwrap_or_default()
341            )))
342        } else {
343            let threads_dir = paths::data_dir().join("threads");
344            std::fs::create_dir_all(&threads_dir)?;
345            let sqlite_path = threads_dir.join("threads.db");
346            Connection::open_file(&sqlite_path.to_string_lossy())
347        };
348
349        connection.exec(indoc! {"
350            CREATE TABLE IF NOT EXISTS threads (
351                id TEXT PRIMARY KEY,
352                summary TEXT NOT NULL,
353                updated_at TEXT NOT NULL,
354                data_type TEXT NOT NULL,
355                data BLOB NOT NULL
356            )
357        "})?()
358        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
359
360        let db = Self {
361            executor,
362            connection: Arc::new(Mutex::new(connection)),
363        };
364
365        Ok(db)
366    }
367
368    fn save_thread_sync(
369        connection: &Arc<Mutex<Connection>>,
370        id: acp::SessionId,
371        thread: DbThread,
372    ) -> Result<()> {
373        const COMPRESSION_LEVEL: i32 = 3;
374
375        #[derive(Serialize)]
376        struct SerializedThread {
377            #[serde(flatten)]
378            thread: DbThread,
379            version: &'static str,
380        }
381
382        let title = thread.title.to_string();
383        let updated_at = thread.updated_at.to_rfc3339();
384        let json_data = serde_json::to_string(&SerializedThread {
385            thread,
386            version: DbThread::VERSION,
387        })?;
388
389        let connection = connection.lock();
390
391        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
392        let data_type = DataType::Zstd;
393        let data = compressed;
394
395        let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
396            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
397        "})?;
398
399        insert((id.0, title, updated_at, data_type, data))?;
400
401        Ok(())
402    }
403
404    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
405        let connection = self.connection.clone();
406
407        self.executor.spawn(async move {
408            let connection = connection.lock();
409
410            let mut select =
411                connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
412                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
413            "})?;
414
415            let rows = select(())?;
416            let mut threads = Vec::new();
417
418            for (id, summary, updated_at) in rows {
419                threads.push(DbThreadMetadata {
420                    id: acp::SessionId::new(id),
421                    title: summary.into(),
422                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
423                });
424            }
425
426            Ok(threads)
427        })
428    }
429
430    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
431        let connection = self.connection.clone();
432
433        self.executor.spawn(async move {
434            let connection = connection.lock();
435            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
436                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
437            "})?;
438
439            let rows = select(id.0)?;
440            if let Some((data_type, data)) = rows.into_iter().next() {
441                let json_data = match data_type {
442                    DataType::Zstd => {
443                        let decompressed = zstd::decode_all(&data[..])?;
444                        String::from_utf8(decompressed)?
445                    }
446                    DataType::Json => String::from_utf8(data)?,
447                };
448                let thread = DbThread::from_json(json_data.as_bytes())?;
449                Ok(Some(thread))
450            } else {
451                Ok(None)
452            }
453        })
454    }
455
456    pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
457        let connection = self.connection.clone();
458
459        self.executor
460            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
461    }
462
463    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
464        let connection = self.connection.clone();
465
466        self.executor.spawn(async move {
467            let connection = connection.lock();
468
469            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
470                DELETE FROM threads WHERE id = ?
471            "})?;
472
473            delete(id.0)?;
474
475            Ok(())
476        })
477    }
478
479    pub fn delete_threads(&self) -> Task<Result<()>> {
480        let connection = self.connection.clone();
481
482        self.executor.spawn(async move {
483            let connection = connection.lock();
484
485            let mut delete = connection.exec_bound::<()>(indoc! {"
486                DELETE FROM threads
487            "})?;
488
489            delete(())?;
490
491            Ok(())
492        })
493    }
494}
495
// Unit tests for shared-thread serialization, serde defaults on `DbThread`, and the
// list/save behavior of `ThreadsDatabase`.
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    // `to_bytes` followed by `from_bytes` must preserve title, version and timestamp.
    #[test]
    fn test_shared_thread_roundtrip() {
        let original = SharedThread {
            title: "Test Thread".into(),
            messages: vec![],
            updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
            model: None,
            version: SharedThread::VERSION.to_string(),
        };

        let bytes = original.to_bytes().expect("Failed to serialize");
        let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(restored.title, original.title);
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.updated_at, original.updated_at);
    }

    // A document without the `imported` key must deserialize with `imported == false`.
    #[test]
    fn test_imported_flag_defaults_to_false() {
        // Simulate deserializing a thread without the imported field (backwards compatibility).
        let json = r#"{
            "title": "Old Thread",
            "messages": [],
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;

        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");

        assert!(
            !db_thread.imported,
            "Legacy threads without imported field should default to false"
        );
    }

    // Builds a session id from a string literal.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    // Builds an empty thread with the given title and timestamp; every other field is
    // `None`, empty or default.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
        }
    }

    // Two threads saved with different timestamps must be listed newest first.
    #[gpui::test]
    async fn test_list_threads_orders_by_updated_at(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let newer_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(older_id.clone(), older_thread)
            .await
            .unwrap();
        database
            .save_thread(newer_id.clone(), newer_thread)
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    // Saving twice under the same id must leave a single row carrying the second
    // thread's title and timestamp.
    #[gpui::test]
    async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
        let database = ThreadsDatabase::new(cx.executor()).unwrap();

        let thread_id = session_id("thread-a");
        let original_thread = make_thread(
            "Thread A",
            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
        );
        let updated_thread = make_thread(
            "Thread B",
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
        );

        database
            .save_thread(thread_id.clone(), original_thread)
            .await
            .unwrap();
        database
            .save_thread(thread_id.clone(), updated_thread)
            .await
            .unwrap();

        let entries = database.list_threads().await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, thread_id);
        assert_eq!(entries[0].title.as_ref(), "Thread B");
        assert_eq!(
            entries[0].updated_at,
            Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
        );
    }
}
620    }
621}