db.rs

  1use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
  2use acp_thread::UserMessageId;
  3use agent_client_protocol as acp;
  4use agent_settings::{AgentProfileId, CompletionMode};
  5use anyhow::{Result, anyhow};
  6use chrono::{DateTime, Utc};
  7use collections::{HashMap, IndexMap};
  8use futures::{FutureExt, future::Shared};
  9use gpui::{BackgroundExecutor, Global, Task};
 10use indoc::indoc;
 11use parking_lot::Mutex;
 12use serde::{Deserialize, Serialize};
 13use sqlez::{
 14    bindable::{Bind, Column},
 15    connection::Connection,
 16    statement::Statement,
 17};
 18use std::sync::Arc;
 19use ui::{App, SharedString};
 20use zed_env_vars::ZED_STATELESS;
 21
/// Message type stored inside a persisted thread (alias of the crate-level message enum).
pub type DbMessage = crate::Message;
/// Detailed-summary state as stored by the legacy thread format.
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
/// Serialized language-model selection, shared with the legacy thread format.
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
 25
/// Lightweight description of a stored thread: its id, title and last-update time,
/// without the message payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
    /// Session identifier of the thread.
    pub id: acp::SessionId,
    /// Human-readable title. When deserializing, the field name `summary` is accepted as well.
    #[serde(alias = "summary")]
    pub title: SharedString,
    /// Time the thread was last updated, in UTC.
    pub updated_at: DateTime<Utc>,
}
 33
/// Full persisted form of a thread.
///
/// Every field marked `#[serde(default)]` may be absent from the serialized data, in which
/// case it is filled with its type's `Default` value on deserialization.
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
    /// Human-readable title of the thread.
    pub title: SharedString,
    /// Messages of the thread, in order.
    pub messages: Vec<DbMessage>,
    /// Time the thread was last updated, in UTC.
    pub updated_at: DateTime<Utc>,
    /// Detailed summary text, if one is stored.
    #[serde(default)]
    pub detailed_summary: Option<SharedString>,
    /// Project snapshot associated with the thread, if one is stored.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
    /// Token usage accumulated over the whole thread.
    #[serde(default)]
    pub cumulative_token_usage: language_model::TokenUsage,
    /// Token usage recorded per request, keyed by a user message id.
    #[serde(default)]
    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
    /// Language model stored with the thread, if any.
    #[serde(default)]
    pub model: Option<DbLanguageModel>,
    /// Completion mode stored with the thread, if any.
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    /// Agent profile stored with the thread, if any.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
 54
 55impl DbThread {
 56    pub const VERSION: &'static str = "0.3.0";
 57
 58    pub fn from_json(json: &[u8]) -> Result<Self> {
 59        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 60        match saved_thread_json.get("version") {
 61            Some(serde_json::Value::String(version)) => match version.as_str() {
 62                Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
 63                _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
 64                    json,
 65                )?),
 66            },
 67            _ => {
 68                Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
 69            }
 70        }
 71    }
 72
 73    fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
 74        let mut messages = Vec::new();
 75        let mut request_token_usage = HashMap::default();
 76
 77        let mut last_user_message_id = None;
 78        for (ix, msg) in thread.messages.into_iter().enumerate() {
 79            let message = match msg.role {
 80                language_model::Role::User => {
 81                    let mut content = Vec::new();
 82
 83                    // Convert segments to content
 84                    for segment in msg.segments {
 85                        match segment {
 86                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
 87                                content.push(UserMessageContent::Text(text));
 88                            }
 89                            crate::legacy_thread::SerializedMessageSegment::Thinking {
 90                                text,
 91                                ..
 92                            } => {
 93                                // User messages don't have thinking segments, but handle gracefully
 94                                content.push(UserMessageContent::Text(text));
 95                            }
 96                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
 97                                ..
 98                            } => {
 99                                // User messages don't have redacted thinking, skip.
100                            }
101                        }
102                    }
103
104                    // If no content was added, add context as text if available
105                    if content.is_empty() && !msg.context.is_empty() {
106                        content.push(UserMessageContent::Text(msg.context));
107                    }
108
109                    let id = UserMessageId::new();
110                    last_user_message_id = Some(id.clone());
111
112                    crate::Message::User(UserMessage {
113                        // MessageId from old format can't be meaningfully converted, so generate a new one
114                        id,
115                        content,
116                    })
117                }
118                language_model::Role::Assistant => {
119                    let mut content = Vec::new();
120
121                    // Convert segments to content
122                    for segment in msg.segments {
123                        match segment {
124                            crate::legacy_thread::SerializedMessageSegment::Text { text } => {
125                                content.push(AgentMessageContent::Text(text));
126                            }
127                            crate::legacy_thread::SerializedMessageSegment::Thinking {
128                                text,
129                                signature,
130                            } => {
131                                content.push(AgentMessageContent::Thinking { text, signature });
132                            }
133                            crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
134                                data,
135                            } => {
136                                content.push(AgentMessageContent::RedactedThinking(data));
137                            }
138                        }
139                    }
140
141                    // Convert tool uses
142                    let mut tool_names_by_id = HashMap::default();
143                    for tool_use in msg.tool_uses {
144                        tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
145                        content.push(AgentMessageContent::ToolUse(
146                            language_model::LanguageModelToolUse {
147                                id: tool_use.id,
148                                name: tool_use.name.into(),
149                                raw_input: serde_json::to_string(&tool_use.input)
150                                    .unwrap_or_default(),
151                                input: tool_use.input,
152                                is_input_complete: true,
153                            },
154                        ));
155                    }
156
157                    // Convert tool results
158                    let mut tool_results = IndexMap::default();
159                    for tool_result in msg.tool_results {
160                        let name = tool_names_by_id
161                            .remove(&tool_result.tool_use_id)
162                            .unwrap_or_else(|| SharedString::from("unknown"));
163                        tool_results.insert(
164                            tool_result.tool_use_id.clone(),
165                            language_model::LanguageModelToolResult {
166                                tool_use_id: tool_result.tool_use_id,
167                                tool_name: name.into(),
168                                is_error: tool_result.is_error,
169                                content: tool_result.content,
170                                output: tool_result.output,
171                            },
172                        );
173                    }
174
175                    if let Some(last_user_message_id) = &last_user_message_id
176                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
177                    {
178                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
179                    }
180
181                    crate::Message::Agent(AgentMessage {
182                        content,
183                        tool_results,
184                    })
185                }
186                language_model::Role::System => {
187                    // Skip system messages as they're not supported in the new format
188                    continue;
189                }
190            };
191
192            messages.push(message);
193        }
194
195        Ok(Self {
196            title: thread.summary,
197            messages,
198            updated_at: thread.updated_at,
199            detailed_summary: match thread.detailed_summary_state {
200                crate::legacy_thread::DetailedSummaryState::NotGenerated
201                | crate::legacy_thread::DetailedSummaryState::Generating => None,
202                crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
203            },
204            initial_project_snapshot: thread.initial_project_snapshot,
205            cumulative_token_usage: thread.cumulative_token_usage,
206            request_token_usage,
207            model: thread.model,
208            completion_mode: thread.completion_mode,
209            profile: thread.profile,
210        })
211    }
212}
213
214#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
215pub enum DataType {
216    #[serde(rename = "json")]
217    Json,
218    #[serde(rename = "zstd")]
219    Zstd,
220}
221
222impl Bind for DataType {
223    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
224        let value = match self {
225            DataType::Json => "json",
226            DataType::Zstd => "zstd",
227        };
228        value.bind(statement, start_index)
229    }
230}
231
232impl Column for DataType {
233    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
234        let (value, next_index) = String::column(statement, start_index)?;
235        let data_type = match value.as_str() {
236            "json" => DataType::Json,
237            "zstd" => DataType::Zstd,
238            _ => anyhow::bail!("Unknown data type: {}", value),
239        };
240        Ok((data_type, next_index))
241    }
242}
243
/// Handle to the database that stores threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which the database operations are spawned.
    executor: BackgroundExecutor,
    /// Database connection, shared between spawned tasks and guarded by a mutex.
    connection: Arc<Mutex<Connection>>,
}
248
/// App-global wrapper around the shared task that opens the threads database, so that the
/// same task (and its result) can be handed out to every caller of `ThreadsDatabase::connect`.
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);

// Marker impl that allows the wrapper to be stored as a global on the app.
impl Global for GlobalThreadsDatabase {}
252
253impl ThreadsDatabase {
254    pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
255        if cx.has_global::<GlobalThreadsDatabase>() {
256            return cx.global::<GlobalThreadsDatabase>().0.clone();
257        }
258        let executor = cx.background_executor().clone();
259        let task = executor
260            .spawn({
261                let executor = executor.clone();
262                async move {
263                    match ThreadsDatabase::new(executor) {
264                        Ok(db) => Ok(Arc::new(db)),
265                        Err(err) => Err(Arc::new(err)),
266                    }
267                }
268            })
269            .shared();
270
271        cx.set_global(GlobalThreadsDatabase(task.clone()));
272        task
273    }
274
275    pub fn new(executor: BackgroundExecutor) -> Result<Self> {
276        let connection = if *ZED_STATELESS {
277            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
278        } else if cfg!(any(feature = "test-support", test)) {
279            // rust stores the name of the test on the current thread.
280            // We use this to automatically create a database that will
281            // be shared within the test (for the test_retrieve_old_thread)
282            // but not with concurrent tests.
283            let thread = std::thread::current();
284            let test_name = thread.name();
285            Connection::open_memory(Some(&format!(
286                "THREAD_FALLBACK_{}",
287                test_name.unwrap_or_default()
288            )))
289        } else {
290            let threads_dir = paths::data_dir().join("threads");
291            std::fs::create_dir_all(&threads_dir)?;
292            let sqlite_path = threads_dir.join("threads.db");
293            Connection::open_file(&sqlite_path.to_string_lossy())
294        };
295
296        connection.exec(indoc! {"
297            CREATE TABLE IF NOT EXISTS threads (
298                id TEXT PRIMARY KEY,
299                summary TEXT NOT NULL,
300                updated_at TEXT NOT NULL,
301                data_type TEXT NOT NULL,
302                data BLOB NOT NULL
303            )
304        "})?()
305        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
306
307        let db = Self {
308            executor,
309            connection: Arc::new(Mutex::new(connection)),
310        };
311
312        Ok(db)
313    }
314
315    fn save_thread_sync(
316        connection: &Arc<Mutex<Connection>>,
317        id: acp::SessionId,
318        thread: DbThread,
319    ) -> Result<()> {
320        const COMPRESSION_LEVEL: i32 = 3;
321
322        #[derive(Serialize)]
323        struct SerializedThread {
324            #[serde(flatten)]
325            thread: DbThread,
326            version: &'static str,
327        }
328
329        let title = thread.title.to_string();
330        let updated_at = thread.updated_at.to_rfc3339();
331        let json_data = serde_json::to_string(&SerializedThread {
332            thread,
333            version: DbThread::VERSION,
334        })?;
335
336        let connection = connection.lock();
337
338        let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
339        let data_type = DataType::Zstd;
340        let data = compressed;
341
342        let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
343            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
344        "})?;
345
346        insert((id.0, title, updated_at, data_type, data))?;
347
348        Ok(())
349    }
350
351    pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
352        let connection = self.connection.clone();
353
354        self.executor.spawn(async move {
355            let connection = connection.lock();
356
357            let mut select =
358                connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
359                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
360            "})?;
361
362            let rows = select(())?;
363            let mut threads = Vec::new();
364
365            for (id, summary, updated_at) in rows {
366                threads.push(DbThreadMetadata {
367                    id: acp::SessionId(id),
368                    title: summary.into(),
369                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
370                });
371            }
372
373            Ok(threads)
374        })
375    }
376
377    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
378        let connection = self.connection.clone();
379
380        self.executor.spawn(async move {
381            let connection = connection.lock();
382            let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
383                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
384            "})?;
385
386            let rows = select(id.0)?;
387            if let Some((data_type, data)) = rows.into_iter().next() {
388                let json_data = match data_type {
389                    DataType::Zstd => {
390                        let decompressed = zstd::decode_all(&data[..])?;
391                        String::from_utf8(decompressed)?
392                    }
393                    DataType::Json => String::from_utf8(data)?,
394                };
395                let thread = DbThread::from_json(json_data.as_bytes())?;
396                Ok(Some(thread))
397            } else {
398                Ok(None)
399            }
400        })
401    }
402
403    pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
404        let connection = self.connection.clone();
405
406        self.executor
407            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
408    }
409
410    pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
411        let connection = self.connection.clone();
412
413        self.executor.spawn(async move {
414            let connection = connection.lock();
415
416            let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
417                DELETE FROM threads WHERE id = ?
418            "})?;
419
420            delete(id.0)?;
421
422            Ok(())
423        })
424    }
425}