1use std::borrow::Cow;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use assistant_settings::{AgentProfile, AssistantSettings};
7use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
8use chrono::{DateTime, Utc};
9use collections::HashMap;
10use context_server::manager::ContextServerManager;
11use context_server::{ContextServerFactoryRegistry, ContextServerTool};
12use futures::FutureExt as _;
13use futures::future::{self, BoxFuture, Shared};
14use gpui::{
15 App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task, prelude::*,
16};
17use heed::Database;
18use heed::types::SerdeBincode;
19use language_model::{LanguageModelToolUseId, Role, TokenUsage};
20use project::Project;
21use prompt_store::PromptBuilder;
22use serde::{Deserialize, Serialize};
23use settings::Settings as _;
24use util::ResultExt as _;
25
26use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
27
28pub fn init(cx: &mut App) {
29 ThreadsDatabase::init(cx);
30}
31
/// Tracks the saved threads for one [`Project`] and creates, opens, saves and deletes them.
///
/// It also owns the project's [`ContextServerManager`] and records which tools each context
/// server registered in the shared [`ToolWorkingSet`], so they can be removed again later.
pub struct ThreadStore {
    /// Project that threads created or opened through this store are bound to.
    project: Entity<Project>,
    /// Tool set handed to every thread; enabled/disabled according to the active agent profile.
    tools: Arc<ToolWorkingSet>,
    /// Prompt builder handed to every thread created or deserialized by this store.
    prompt_builder: Arc<PromptBuilder>,
    /// Context-server manager for `project`; its events drive tool (un)registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered per context-server id, kept so they can be removed on stop.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata of saved threads, in database listing order (`threads()` sorts a copy).
    threads: Vec<SerializedThreadMetadata>,
}
40
41impl ThreadStore {
42 pub fn new(
43 project: Entity<Project>,
44 tools: Arc<ToolWorkingSet>,
45 prompt_builder: Arc<PromptBuilder>,
46 cx: &mut App,
47 ) -> Result<Entity<Self>> {
48 let this = cx.new(|cx| {
49 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
50 let context_server_manager = cx.new(|cx| {
51 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
52 });
53
54 let this = Self {
55 project,
56 tools,
57 prompt_builder,
58 context_server_manager,
59 context_server_tool_ids: HashMap::default(),
60 threads: Vec::new(),
61 };
62 this.load_default_profile(cx);
63 this.register_context_server_handlers(cx);
64 this.reload(cx).detach_and_log_err(cx);
65
66 this
67 });
68
69 Ok(this)
70 }
71
72 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
73 self.context_server_manager.clone()
74 }
75
76 pub fn tools(&self) -> Arc<ToolWorkingSet> {
77 self.tools.clone()
78 }
79
80 /// Returns the number of threads.
81 pub fn thread_count(&self) -> usize {
82 self.threads.len()
83 }
84
85 pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
86 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
87 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
88 threads
89 }
90
91 pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
92 self.threads().into_iter().take(limit).collect()
93 }
94
95 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
96 cx.new(|cx| {
97 Thread::new(
98 self.project.clone(),
99 self.tools.clone(),
100 self.prompt_builder.clone(),
101 cx,
102 )
103 })
104 }
105
106 pub fn open_thread(
107 &self,
108 id: &ThreadId,
109 cx: &mut Context<Self>,
110 ) -> Task<Result<Entity<Thread>>> {
111 let id = id.clone();
112 let database_future = ThreadsDatabase::global_future(cx);
113 cx.spawn(async move |this, cx| {
114 let database = database_future.await.map_err(|err| anyhow!(err))?;
115 let thread = database
116 .try_find_thread(id.clone())
117 .await?
118 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
119
120 let thread = this.update(cx, |this, cx| {
121 cx.new(|cx| {
122 Thread::deserialize(
123 id.clone(),
124 thread,
125 this.project.clone(),
126 this.tools.clone(),
127 this.prompt_builder.clone(),
128 cx,
129 )
130 })
131 })?;
132
133 let (system_prompt_context, load_error) = thread
134 .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
135 .await;
136 thread.update(cx, |thread, cx| {
137 thread.set_system_prompt_context(system_prompt_context);
138 if let Some(load_error) = load_error {
139 cx.emit(ThreadEvent::ShowError(load_error));
140 }
141 })?;
142
143 Ok(thread)
144 })
145 }
146
147 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
148 let (metadata, serialized_thread) =
149 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
150
151 let database_future = ThreadsDatabase::global_future(cx);
152 cx.spawn(async move |this, cx| {
153 let serialized_thread = serialized_thread.await?;
154 let database = database_future.await.map_err(|err| anyhow!(err))?;
155 database.save_thread(metadata, serialized_thread).await?;
156
157 this.update(cx, |this, cx| this.reload(cx))?.await
158 })
159 }
160
161 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
162 let id = id.clone();
163 let database_future = ThreadsDatabase::global_future(cx);
164 cx.spawn(async move |this, cx| {
165 let database = database_future.await.map_err(|err| anyhow!(err))?;
166 database.delete_thread(id.clone()).await?;
167
168 this.update(cx, |this, _cx| {
169 this.threads.retain(|thread| thread.id != id)
170 })
171 })
172 }
173
174 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
175 let database_future = ThreadsDatabase::global_future(cx);
176 cx.spawn(async move |this, cx| {
177 let threads = database_future
178 .await
179 .map_err(|err| anyhow!(err))?
180 .list_threads()
181 .await?;
182
183 this.update(cx, |this, cx| {
184 this.threads = threads;
185 cx.notify();
186 })
187 })
188 }
189
190 fn load_default_profile(&self, cx: &Context<Self>) {
191 let assistant_settings = AssistantSettings::get_global(cx);
192
193 self.load_profile_by_id(&assistant_settings.default_profile, cx);
194 }
195
196 pub fn load_profile_by_id(&self, profile_id: &Arc<str>, cx: &Context<Self>) {
197 let assistant_settings = AssistantSettings::get_global(cx);
198
199 if let Some(profile) = assistant_settings.profiles.get(profile_id) {
200 self.load_profile(profile);
201 }
202 }
203
204 pub fn load_profile(&self, profile: &AgentProfile) {
205 self.tools.disable_all_tools();
206 self.tools.enable(
207 ToolSource::Native,
208 &profile
209 .tools
210 .iter()
211 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
212 .collect::<Vec<_>>(),
213 );
214
215 for (context_server_id, preset) in &profile.context_servers {
216 self.tools.enable(
217 ToolSource::ContextServer {
218 id: context_server_id.clone().into(),
219 },
220 &preset
221 .tools
222 .iter()
223 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
224 .collect::<Vec<_>>(),
225 )
226 }
227 }
228
229 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
230 cx.subscribe(
231 &self.context_server_manager.clone(),
232 Self::handle_context_server_event,
233 )
234 .detach();
235 }
236
237 fn handle_context_server_event(
238 &mut self,
239 context_server_manager: Entity<ContextServerManager>,
240 event: &context_server::manager::Event,
241 cx: &mut Context<Self>,
242 ) {
243 let tool_working_set = self.tools.clone();
244 match event {
245 context_server::manager::Event::ServerStarted { server_id } => {
246 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
247 let context_server_manager = context_server_manager.clone();
248 cx.spawn({
249 let server = server.clone();
250 let server_id = server_id.clone();
251 async move |this, cx| {
252 let Some(protocol) = server.client() else {
253 return;
254 };
255
256 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
257 if let Some(tools) = protocol.list_tools().await.log_err() {
258 let tool_ids = tools
259 .tools
260 .into_iter()
261 .map(|tool| {
262 log::info!(
263 "registering context server tool: {:?}",
264 tool.name
265 );
266 tool_working_set.insert(Arc::new(
267 ContextServerTool::new(
268 context_server_manager.clone(),
269 server.id(),
270 tool,
271 ),
272 ))
273 })
274 .collect::<Vec<_>>();
275
276 this.update(cx, |this, _cx| {
277 this.context_server_tool_ids.insert(server_id, tool_ids);
278 })
279 .log_err();
280 }
281 }
282 }
283 })
284 .detach();
285 }
286 }
287 context_server::manager::Event::ServerStopped { server_id } => {
288 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
289 tool_working_set.remove(&tool_ids);
290 }
291 }
292 }
293 }
294}
295
/// Lightweight summary of a saved thread, as listed by `ThreadsDatabase::list_threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Thread summary copied from the stored `SerializedThread`.
    pub summary: SharedString,
    /// Last-update timestamp copied from the stored thread; used to order thread lists.
    pub updated_at: DateTime<Utc>,
}
302
/// Persisted form of a thread, stored as JSON in the threads database.
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
    /// Format version; `SerializedThread::VERSION` for data written or upgraded by this code.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    /// Defaults to `None` when absent from the serialized data.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    /// Defaults to `TokenUsage::default()` when absent from the serialized data.
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
}
314
315impl SerializedThread {
316 pub const VERSION: &'static str = "0.1.0";
317
318 pub fn from_json(json: &[u8]) -> Result<Self> {
319 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
320 match saved_thread_json.get("version") {
321 Some(serde_json::Value::String(version)) => match version.as_str() {
322 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
323 saved_thread_json,
324 )?),
325 _ => Err(anyhow!(
326 "unrecognized serialized thread version: {}",
327 version
328 )),
329 },
330 None => {
331 let saved_thread =
332 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
333 Ok(saved_thread.upgrade())
334 }
335 version => Err(anyhow!(
336 "unrecognized serialized thread version: {:?}",
337 version
338 )),
339 }
340 }
341}
342
/// Persisted form of a single thread message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Message content segments; empty when absent from the serialized data.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    /// Empty when absent from the serialized data.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Empty when absent from the serialized data.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
354
/// One content segment of a message.
///
/// Serialized as an internally tagged object, e.g. `{"type": "text", "text": "..."}` or
/// `{"type": "thinking", "text": "..."}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "thinking")]
    Thinking { text: String },
}
363
/// Persisted record of a tool invocation within a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Tool input stored as raw JSON.
    pub input: serde_json::Value,
}
370
/// Persisted result of a tool invocation.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// Presumably the `id` of the `SerializedToolUse` this result belongs to — confirm with `Thread`.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether this result represents a tool error.
    pub is_error: bool,
    pub content: Arc<str>,
}
377
/// Thread format used before the `version` field existed.
///
/// `SerializedThread::from_json` parses unversioned JSON into this type and upgrades it.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    /// Defaults to `None` when absent from the serialized data.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
386
387impl LegacySerializedThread {
388 pub fn upgrade(self) -> SerializedThread {
389 SerializedThread {
390 version: SerializedThread::VERSION.to_string(),
391 summary: self.summary,
392 updated_at: self.updated_at,
393 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
394 initial_project_snapshot: self.initial_project_snapshot,
395 cumulative_token_usage: TokenUsage::default(),
396 }
397 }
398}
399
/// Message format used before messages were split into segments; content is a single `text`.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    /// Empty when absent from the serialized data.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Empty when absent from the serialized data.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
410
411impl LegacySerializedMessage {
412 fn upgrade(self) -> SerializedMessage {
413 SerializedMessage {
414 id: self.id,
415 role: self.role,
416 segments: vec![SerializedMessageSegment::Text { text: self.text }],
417 tool_uses: self.tool_uses,
418 tool_results: self.tool_results,
419 }
420 }
421}
422
/// gpui global holding the (possibly still pending) result of opening the threads database.
///
/// The future is `Shared` so every caller awaits the same open operation. The error is wrapped
/// in `Arc` because a shared future's output must be `Clone`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
428
/// heed-backed store of serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which all database operations are spawned.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// The `threads` table: bincode-encoded `ThreadId` keys, JSON-encoded `SerializedThread`
    /// values (see the `BytesEncode`/`BytesDecode` impls for `SerializedThread`).
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
434
435impl heed::BytesEncode<'_> for SerializedThread {
436 type EItem = SerializedThread;
437
438 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
439 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
440 }
441}
442
443impl<'a> heed::BytesDecode<'a> for SerializedThread {
444 type DItem = SerializedThread;
445
446 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
447 // We implement this type manually because we want to call `SerializedThread::from_json`,
448 // instead of the Deserialize trait implementation for `SerializedThread`.
449 SerializedThread::from_json(bytes).map_err(Into::into)
450 }
451}
452
453impl ThreadsDatabase {
454 fn global_future(
455 cx: &mut App,
456 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
457 GlobalThreadsDatabase::global(cx).0.clone()
458 }
459
460 fn init(cx: &mut App) {
461 let executor = cx.background_executor().clone();
462 let database_future = executor
463 .spawn({
464 let executor = executor.clone();
465 let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
466 async move { ThreadsDatabase::new(database_path, executor) }
467 })
468 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
469 .boxed()
470 .shared();
471
472 cx.set_global(GlobalThreadsDatabase(database_future));
473 }
474
475 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
476 std::fs::create_dir_all(&path)?;
477
478 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
479 let env = unsafe {
480 heed::EnvOpenOptions::new()
481 .map_size(ONE_GB_IN_BYTES)
482 .max_dbs(1)
483 .open(path)?
484 };
485
486 let mut txn = env.write_txn()?;
487 let threads = env.create_database(&mut txn, Some("threads"))?;
488 txn.commit()?;
489
490 Ok(Self {
491 executor,
492 env,
493 threads,
494 })
495 }
496
497 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
498 let env = self.env.clone();
499 let threads = self.threads;
500
501 self.executor.spawn(async move {
502 let txn = env.read_txn()?;
503 let mut iter = threads.iter(&txn)?;
504 let mut threads = Vec::new();
505 while let Some((key, value)) = iter.next().transpose()? {
506 threads.push(SerializedThreadMetadata {
507 id: key,
508 summary: value.summary,
509 updated_at: value.updated_at,
510 });
511 }
512
513 Ok(threads)
514 })
515 }
516
517 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
518 let env = self.env.clone();
519 let threads = self.threads;
520
521 self.executor.spawn(async move {
522 let txn = env.read_txn()?;
523 let thread = threads.get(&txn, &id)?;
524 Ok(thread)
525 })
526 }
527
528 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
529 let env = self.env.clone();
530 let threads = self.threads;
531
532 self.executor.spawn(async move {
533 let mut txn = env.write_txn()?;
534 threads.put(&mut txn, &id, &thread)?;
535 txn.commit()?;
536 Ok(())
537 })
538 }
539
540 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
541 let env = self.env.clone();
542 let threads = self.threads;
543
544 self.executor.spawn(async move {
545 let mut txn = env.write_txn()?;
546 threads.delete(&mut txn, &id)?;
547 txn.commit()?;
548 Ok(())
549 })
550 }
551}