1use std::borrow::Cow;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use assistant_settings::{AgentProfile, AssistantSettings};
7use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
8use chrono::{DateTime, Utc};
9use collections::HashMap;
10use context_server::manager::ContextServerManager;
11use context_server::{ContextServerFactoryRegistry, ContextServerTool};
12use futures::FutureExt as _;
13use futures::future::{self, BoxFuture, Shared};
14use gpui::{
15 App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
16 prelude::*,
17};
18use heed::Database;
19use heed::types::SerdeBincode;
20use language_model::{LanguageModelToolUseId, Role, TokenUsage};
21use project::Project;
22use prompt_store::PromptBuilder;
23use serde::{Deserialize, Serialize};
24use settings::{Settings as _, SettingsStore};
25use util::ResultExt as _;
26
27use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
28
29pub fn init(cx: &mut App) {
30 ThreadsDatabase::init(cx);
31}
32
/// Keeps the list of saved agent threads for a project.
///
/// It also keeps the shared tool working set in sync with the active agent
/// profile and with the tools exposed by running context servers.
pub struct ThreadStore {
    /// Project that threads created or opened by this store are attached to.
    project: Entity<Project>,
    /// Tool set shared with every thread.
    /// Tools are enabled and disabled here according to the active profile.
    tools: Arc<ToolWorkingSet>,
    /// Handed to each `Thread` this store creates or deserializes.
    prompt_builder: Arc<PromptBuilder>,
    /// Manages the context servers whose tools are registered into `tools`.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered for each context server, keyed by server id.
    /// They are removed from `tools` when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata of saved threads, in the order the database listed them.
    /// Use `threads()` for a newest-first view.
    threads: Vec<SerializedThreadMetadata>,
    /// Subscriptions kept alive for the lifetime of the store (the settings observer).
    _subscriptions: Vec<Subscription>,
}
42
43impl ThreadStore {
44 pub fn new(
45 project: Entity<Project>,
46 tools: Arc<ToolWorkingSet>,
47 prompt_builder: Arc<PromptBuilder>,
48 cx: &mut App,
49 ) -> Result<Entity<Self>> {
50 let this = cx.new(|cx| {
51 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
52 let context_server_manager = cx.new(|cx| {
53 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
54 });
55 let settings_subscription =
56 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
57 this.load_default_profile(cx);
58 });
59
60 let this = Self {
61 project,
62 tools,
63 prompt_builder,
64 context_server_manager,
65 context_server_tool_ids: HashMap::default(),
66 threads: Vec::new(),
67 _subscriptions: vec![settings_subscription],
68 };
69 this.load_default_profile(cx);
70 this.register_context_server_handlers(cx);
71 this.reload(cx).detach_and_log_err(cx);
72
73 this
74 });
75
76 Ok(this)
77 }
78
79 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
80 self.context_server_manager.clone()
81 }
82
83 pub fn tools(&self) -> Arc<ToolWorkingSet> {
84 self.tools.clone()
85 }
86
87 /// Returns the number of threads.
88 pub fn thread_count(&self) -> usize {
89 self.threads.len()
90 }
91
92 pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
93 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
94 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
95 threads
96 }
97
98 pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
99 self.threads().into_iter().take(limit).collect()
100 }
101
102 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
103 cx.new(|cx| {
104 Thread::new(
105 self.project.clone(),
106 self.tools.clone(),
107 self.prompt_builder.clone(),
108 cx,
109 )
110 })
111 }
112
113 pub fn open_thread(
114 &self,
115 id: &ThreadId,
116 cx: &mut Context<Self>,
117 ) -> Task<Result<Entity<Thread>>> {
118 let id = id.clone();
119 let database_future = ThreadsDatabase::global_future(cx);
120 cx.spawn(async move |this, cx| {
121 let database = database_future.await.map_err(|err| anyhow!(err))?;
122 let thread = database
123 .try_find_thread(id.clone())
124 .await?
125 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
126
127 let thread = this.update(cx, |this, cx| {
128 cx.new(|cx| {
129 Thread::deserialize(
130 id.clone(),
131 thread,
132 this.project.clone(),
133 this.tools.clone(),
134 this.prompt_builder.clone(),
135 cx,
136 )
137 })
138 })?;
139
140 let (system_prompt_context, load_error) = thread
141 .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
142 .await;
143 thread.update(cx, |thread, cx| {
144 thread.set_system_prompt_context(system_prompt_context);
145 if let Some(load_error) = load_error {
146 cx.emit(ThreadEvent::ShowError(load_error));
147 }
148 })?;
149
150 Ok(thread)
151 })
152 }
153
154 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
155 let (metadata, serialized_thread) =
156 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
157
158 let database_future = ThreadsDatabase::global_future(cx);
159 cx.spawn(async move |this, cx| {
160 let serialized_thread = serialized_thread.await?;
161 let database = database_future.await.map_err(|err| anyhow!(err))?;
162 database.save_thread(metadata, serialized_thread).await?;
163
164 this.update(cx, |this, cx| this.reload(cx))?.await
165 })
166 }
167
168 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
169 let id = id.clone();
170 let database_future = ThreadsDatabase::global_future(cx);
171 cx.spawn(async move |this, cx| {
172 let database = database_future.await.map_err(|err| anyhow!(err))?;
173 database.delete_thread(id.clone()).await?;
174
175 this.update(cx, |this, _cx| {
176 this.threads.retain(|thread| thread.id != id)
177 })
178 })
179 }
180
181 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
182 let database_future = ThreadsDatabase::global_future(cx);
183 cx.spawn(async move |this, cx| {
184 let threads = database_future
185 .await
186 .map_err(|err| anyhow!(err))?
187 .list_threads()
188 .await?;
189
190 this.update(cx, |this, cx| {
191 this.threads = threads;
192 cx.notify();
193 })
194 })
195 }
196
197 fn load_default_profile(&self, cx: &Context<Self>) {
198 let assistant_settings = AssistantSettings::get_global(cx);
199
200 self.load_profile_by_id(&assistant_settings.default_profile, cx);
201 }
202
203 pub fn load_profile_by_id(&self, profile_id: &Arc<str>, cx: &Context<Self>) {
204 let assistant_settings = AssistantSettings::get_global(cx);
205
206 if let Some(profile) = assistant_settings.profiles.get(profile_id) {
207 self.load_profile(profile, cx);
208 }
209 }
210
211 pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
212 self.tools.disable_all_tools();
213 self.tools.enable(
214 ToolSource::Native,
215 &profile
216 .tools
217 .iter()
218 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
219 .collect::<Vec<_>>(),
220 );
221
222 if profile.enable_all_context_servers {
223 for context_server in self.context_server_manager.read(cx).all_servers() {
224 self.tools.enable_source(
225 ToolSource::ContextServer {
226 id: context_server.id().into(),
227 },
228 cx,
229 );
230 }
231 } else {
232 for (context_server_id, preset) in &profile.context_servers {
233 self.tools.enable(
234 ToolSource::ContextServer {
235 id: context_server_id.clone().into(),
236 },
237 &preset
238 .tools
239 .iter()
240 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
241 .collect::<Vec<_>>(),
242 )
243 }
244 }
245 }
246
247 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
248 cx.subscribe(
249 &self.context_server_manager.clone(),
250 Self::handle_context_server_event,
251 )
252 .detach();
253 }
254
255 fn handle_context_server_event(
256 &mut self,
257 context_server_manager: Entity<ContextServerManager>,
258 event: &context_server::manager::Event,
259 cx: &mut Context<Self>,
260 ) {
261 let tool_working_set = self.tools.clone();
262 match event {
263 context_server::manager::Event::ServerStarted { server_id } => {
264 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
265 let context_server_manager = context_server_manager.clone();
266 cx.spawn({
267 let server = server.clone();
268 let server_id = server_id.clone();
269 async move |this, cx| {
270 let Some(protocol) = server.client() else {
271 return;
272 };
273
274 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
275 if let Some(tools) = protocol.list_tools().await.log_err() {
276 let tool_ids = tools
277 .tools
278 .into_iter()
279 .map(|tool| {
280 log::info!(
281 "registering context server tool: {:?}",
282 tool.name
283 );
284 tool_working_set.insert(Arc::new(
285 ContextServerTool::new(
286 context_server_manager.clone(),
287 server.id(),
288 tool,
289 ),
290 ))
291 })
292 .collect::<Vec<_>>();
293
294 this.update(cx, |this, cx| {
295 this.context_server_tool_ids.insert(server_id, tool_ids);
296 this.load_default_profile(cx);
297 })
298 .log_err();
299 }
300 }
301 }
302 })
303 .detach();
304 }
305 }
306 context_server::manager::Event::ServerStopped { server_id } => {
307 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
308 tool_working_set.remove(&tool_ids);
309 self.load_default_profile(cx);
310 }
311 }
312 }
313 }
314}
315
/// Summary fields of a saved thread, kept in memory by `ThreadStore` for listing.
///
/// Built from each stored `SerializedThread` in `ThreadsDatabase::list_threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Copied from `SerializedThread::summary`.
    pub summary: SharedString,
    /// Copied from `SerializedThread::updated_at`; `ThreadStore::threads` sorts on it, newest first.
    pub updated_at: DateTime<Utc>,
}
322
/// Full persisted representation of a thread.
///
/// Stored as JSON by the `heed::BytesEncode` impl for this type. It is read
/// back through [`SerializedThread::from_json`], which dispatches on `version`.
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
    /// Format version; must equal `SerializedThread::VERSION` to load directly.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    /// `None` when the field is absent from an older record.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    /// `TokenUsage::default()` when the field is absent from an older record.
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
}
334
335impl SerializedThread {
336 pub const VERSION: &'static str = "0.1.0";
337
338 pub fn from_json(json: &[u8]) -> Result<Self> {
339 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
340 match saved_thread_json.get("version") {
341 Some(serde_json::Value::String(version)) => match version.as_str() {
342 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
343 saved_thread_json,
344 )?),
345 _ => Err(anyhow!(
346 "unrecognized serialized thread version: {}",
347 version
348 )),
349 },
350 None => {
351 let saved_thread =
352 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
353 Ok(saved_thread.upgrade())
354 }
355 version => Err(anyhow!(
356 "unrecognized serialized thread version: {:?}",
357 version
358 )),
359 }
360 }
361}
362
/// A single persisted message within a [`SerializedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Message content pieces, in order; empty when the field is absent.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    /// Tool uses recorded on this message; empty when the field is absent.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results recorded on this message; empty when the field is absent.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
374
/// One piece of a message's content.
///
/// Internally tagged by a `type` key, e.g. `{"type": "text", "text": "..."}`
/// or `{"type": "thinking", "text": "..."}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Regular message text.
    #[serde(rename = "text")]
    Text { text: String },
    /// Text of a `thinking` segment (presumably model reasoning output — confirm).
    #[serde(rename = "thinking")]
    Thinking { text: String },
}
383
/// A persisted tool invocation.
///
/// Stores the tool-use `id`, the tool `name`, and the raw JSON `input`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
390
/// A persisted tool result.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// Id of the tool use this result answers (presumably a `SerializedToolUse::id`).
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the tool reported an error.
    pub is_error: bool,
    pub content: Arc<str>,
}
397
/// Thread format written before the `version` field existed.
///
/// `SerializedThread::from_json` parses records without a `version` key as this
/// type and converts them with `upgrade`.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    /// `None` when the field is absent.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
406
407impl LegacySerializedThread {
408 pub fn upgrade(self) -> SerializedThread {
409 SerializedThread {
410 version: SerializedThread::VERSION.to_string(),
411 summary: self.summary,
412 updated_at: self.updated_at,
413 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
414 initial_project_snapshot: self.initial_project_snapshot,
415 cumulative_token_usage: TokenUsage::default(),
416 }
417 }
418}
419
/// Message format used by [`LegacySerializedThread`].
///
/// Content was a single plain `text` string rather than a list of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Entire message content; becomes one `SerializedMessageSegment::Text` on upgrade.
    pub text: String,
    /// Empty when the field is absent.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Empty when the field is absent.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
430
431impl LegacySerializedMessage {
432 fn upgrade(self) -> SerializedMessage {
433 SerializedMessage {
434 id: self.id,
435 role: self.role,
436 segments: vec![SerializedMessageSegment::Text { text: self.text }],
437 tool_uses: self.tool_uses,
438 tool_results: self.tool_results,
439 }
440 }
441}
442
/// GPUI global holding a shared future that resolves to the opened
/// [`ThreadsDatabase`].
///
/// Both the success and error values are wrapped in `Arc` so the `Shared`
/// future's output can be cloned for every awaiter.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

// Marks the wrapper as storable via `cx.set_global` and readable via `global`.
impl Global for GlobalThreadsDatabase {}
448
/// `heed` (LMDB) backed storage for serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database read and write is spawned.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Thread records keyed by bincode-encoded `ThreadId`.
    /// Values are JSON-encoded `SerializedThread`s.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
454
455impl heed::BytesEncode<'_> for SerializedThread {
456 type EItem = SerializedThread;
457
458 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
459 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
460 }
461}
462
463impl<'a> heed::BytesDecode<'a> for SerializedThread {
464 type DItem = SerializedThread;
465
466 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
467 // We implement this type manually because we want to call `SerializedThread::from_json`,
468 // instead of the Deserialize trait implementation for `SerializedThread`.
469 SerializedThread::from_json(bytes).map_err(Into::into)
470 }
471}
472
473impl ThreadsDatabase {
474 fn global_future(
475 cx: &mut App,
476 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
477 GlobalThreadsDatabase::global(cx).0.clone()
478 }
479
480 fn init(cx: &mut App) {
481 let executor = cx.background_executor().clone();
482 let database_future = executor
483 .spawn({
484 let executor = executor.clone();
485 let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
486 async move { ThreadsDatabase::new(database_path, executor) }
487 })
488 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
489 .boxed()
490 .shared();
491
492 cx.set_global(GlobalThreadsDatabase(database_future));
493 }
494
495 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
496 std::fs::create_dir_all(&path)?;
497
498 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
499 let env = unsafe {
500 heed::EnvOpenOptions::new()
501 .map_size(ONE_GB_IN_BYTES)
502 .max_dbs(1)
503 .open(path)?
504 };
505
506 let mut txn = env.write_txn()?;
507 let threads = env.create_database(&mut txn, Some("threads"))?;
508 txn.commit()?;
509
510 Ok(Self {
511 executor,
512 env,
513 threads,
514 })
515 }
516
517 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
518 let env = self.env.clone();
519 let threads = self.threads;
520
521 self.executor.spawn(async move {
522 let txn = env.read_txn()?;
523 let mut iter = threads.iter(&txn)?;
524 let mut threads = Vec::new();
525 while let Some((key, value)) = iter.next().transpose()? {
526 threads.push(SerializedThreadMetadata {
527 id: key,
528 summary: value.summary,
529 updated_at: value.updated_at,
530 });
531 }
532
533 Ok(threads)
534 })
535 }
536
537 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
538 let env = self.env.clone();
539 let threads = self.threads;
540
541 self.executor.spawn(async move {
542 let txn = env.read_txn()?;
543 let thread = threads.get(&txn, &id)?;
544 Ok(thread)
545 })
546 }
547
548 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
549 let env = self.env.clone();
550 let threads = self.threads;
551
552 self.executor.spawn(async move {
553 let mut txn = env.write_txn()?;
554 threads.put(&mut txn, &id, &thread)?;
555 txn.commit()?;
556 Ok(())
557 })
558 }
559
560 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
561 let env = self.env.clone();
562 let threads = self.threads;
563
564 self.executor.spawn(async move {
565 let mut txn = env.write_txn()?;
566 threads.delete(&mut txn, &id)?;
567 txn.commit()?;
568 Ok(())
569 })
570 }
571}