1use std::borrow::Cow;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
7use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
8use chrono::{DateTime, Utc};
9use collections::HashMap;
10use context_server::manager::ContextServerManager;
11use context_server::{ContextServerFactoryRegistry, ContextServerTool};
12use futures::FutureExt as _;
13use futures::future::{self, BoxFuture, Shared};
14use gpui::{
15 App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
16 prelude::*,
17};
18use heed::Database;
19use heed::types::SerdeBincode;
20use language_model::{LanguageModelToolUseId, Role, TokenUsage};
21use project::Project;
22use prompt_store::PromptBuilder;
23use serde::{Deserialize, Serialize};
24use settings::{Settings as _, SettingsStore};
25use util::ResultExt as _;
26
27use crate::thread::{
28 DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId,
29};
30
31pub fn init(cx: &mut App) {
32 ThreadsDatabase::init(cx);
33}
34
/// Owns the saved-thread list for a project, plus the tool set and context-server
/// manager shared by the threads it creates or opens.
// Field order is intentional: fields drop in declaration order, so
// `_subscriptions` is torn down last.
pub struct ThreadStore {
    /// Project handed to every [`Thread`] this store creates or deserializes.
    project: Entity<Project>,
    /// Tool set shared with every thread. Agent profiles enable and disable
    /// tools in it, and context-server tools are inserted into it.
    tools: Arc<ToolWorkingSet>,
    /// Prompt builder handed to every thread this store creates or deserializes.
    prompt_builder: Arc<PromptBuilder>,
    /// Context-server manager for `project`. Its start/stop events drive tool
    /// registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered for each started context server, keyed by server id,
    /// so they can be removed from `tools` when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata for saved threads, replaced wholesale on each reload. It is
    /// kept in database order; `threads()` sorts by recency when read.
    threads: Vec<SerializedThreadMetadata>,
    /// Subscriptions held for the lifetime of the store (currently the settings observer).
    _subscriptions: Vec<Subscription>,
}
44
45impl ThreadStore {
46 pub fn new(
47 project: Entity<Project>,
48 tools: Arc<ToolWorkingSet>,
49 prompt_builder: Arc<PromptBuilder>,
50 cx: &mut App,
51 ) -> Result<Entity<Self>> {
52 let this = cx.new(|cx| {
53 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
54 let context_server_manager = cx.new(|cx| {
55 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
56 });
57 let settings_subscription =
58 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
59 this.load_default_profile(cx);
60 });
61
62 let this = Self {
63 project,
64 tools,
65 prompt_builder,
66 context_server_manager,
67 context_server_tool_ids: HashMap::default(),
68 threads: Vec::new(),
69 _subscriptions: vec![settings_subscription],
70 };
71 this.load_default_profile(cx);
72 this.register_context_server_handlers(cx);
73 this.reload(cx).detach_and_log_err(cx);
74
75 this
76 });
77
78 Ok(this)
79 }
80
81 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
82 self.context_server_manager.clone()
83 }
84
85 pub fn tools(&self) -> Arc<ToolWorkingSet> {
86 self.tools.clone()
87 }
88
89 /// Returns the number of threads.
90 pub fn thread_count(&self) -> usize {
91 self.threads.len()
92 }
93
94 pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
95 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
96 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
97 threads
98 }
99
100 pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
101 self.threads().into_iter().take(limit).collect()
102 }
103
104 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
105 cx.new(|cx| {
106 Thread::new(
107 self.project.clone(),
108 self.tools.clone(),
109 self.prompt_builder.clone(),
110 cx,
111 )
112 })
113 }
114
115 pub fn open_thread(
116 &self,
117 id: &ThreadId,
118 cx: &mut Context<Self>,
119 ) -> Task<Result<Entity<Thread>>> {
120 let id = id.clone();
121 let database_future = ThreadsDatabase::global_future(cx);
122 cx.spawn(async move |this, cx| {
123 let database = database_future.await.map_err(|err| anyhow!(err))?;
124 let thread = database
125 .try_find_thread(id.clone())
126 .await?
127 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
128
129 let thread = this.update(cx, |this, cx| {
130 cx.new(|cx| {
131 Thread::deserialize(
132 id.clone(),
133 thread,
134 this.project.clone(),
135 this.tools.clone(),
136 this.prompt_builder.clone(),
137 cx,
138 )
139 })
140 })?;
141
142 let (system_prompt_context, load_error) = thread
143 .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
144 .await;
145 thread.update(cx, |thread, cx| {
146 thread.set_system_prompt_context(system_prompt_context);
147 if let Some(load_error) = load_error {
148 cx.emit(ThreadEvent::ShowError(load_error));
149 }
150 })?;
151
152 Ok(thread)
153 })
154 }
155
156 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
157 let (metadata, serialized_thread) =
158 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
159
160 let database_future = ThreadsDatabase::global_future(cx);
161 cx.spawn(async move |this, cx| {
162 let serialized_thread = serialized_thread.await?;
163 let database = database_future.await.map_err(|err| anyhow!(err))?;
164 database.save_thread(metadata, serialized_thread).await?;
165
166 this.update(cx, |this, cx| this.reload(cx))?.await
167 })
168 }
169
170 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
171 let id = id.clone();
172 let database_future = ThreadsDatabase::global_future(cx);
173 cx.spawn(async move |this, cx| {
174 let database = database_future.await.map_err(|err| anyhow!(err))?;
175 database.delete_thread(id.clone()).await?;
176
177 this.update(cx, |this, _cx| {
178 this.threads.retain(|thread| thread.id != id)
179 })
180 })
181 }
182
183 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
184 let database_future = ThreadsDatabase::global_future(cx);
185 cx.spawn(async move |this, cx| {
186 let threads = database_future
187 .await
188 .map_err(|err| anyhow!(err))?
189 .list_threads()
190 .await?;
191
192 this.update(cx, |this, cx| {
193 this.threads = threads;
194 cx.notify();
195 })
196 })
197 }
198
199 fn load_default_profile(&self, cx: &Context<Self>) {
200 let assistant_settings = AssistantSettings::get_global(cx);
201
202 self.load_profile_by_id(&assistant_settings.default_profile, cx);
203 }
204
205 pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
206 let assistant_settings = AssistantSettings::get_global(cx);
207
208 if let Some(profile) = assistant_settings.profiles.get(profile_id) {
209 self.load_profile(profile, cx);
210 }
211 }
212
213 pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
214 self.tools.disable_all_tools();
215 self.tools.enable(
216 ToolSource::Native,
217 &profile
218 .tools
219 .iter()
220 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
221 .collect::<Vec<_>>(),
222 );
223
224 if profile.enable_all_context_servers {
225 for context_server in self.context_server_manager.read(cx).all_servers() {
226 self.tools.enable_source(
227 ToolSource::ContextServer {
228 id: context_server.id().into(),
229 },
230 cx,
231 );
232 }
233 } else {
234 for (context_server_id, preset) in &profile.context_servers {
235 self.tools.enable(
236 ToolSource::ContextServer {
237 id: context_server_id.clone().into(),
238 },
239 &preset
240 .tools
241 .iter()
242 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
243 .collect::<Vec<_>>(),
244 )
245 }
246 }
247 }
248
249 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
250 cx.subscribe(
251 &self.context_server_manager.clone(),
252 Self::handle_context_server_event,
253 )
254 .detach();
255 }
256
257 fn handle_context_server_event(
258 &mut self,
259 context_server_manager: Entity<ContextServerManager>,
260 event: &context_server::manager::Event,
261 cx: &mut Context<Self>,
262 ) {
263 let tool_working_set = self.tools.clone();
264 match event {
265 context_server::manager::Event::ServerStarted { server_id } => {
266 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
267 let context_server_manager = context_server_manager.clone();
268 cx.spawn({
269 let server = server.clone();
270 let server_id = server_id.clone();
271 async move |this, cx| {
272 let Some(protocol) = server.client() else {
273 return;
274 };
275
276 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
277 if let Some(tools) = protocol.list_tools().await.log_err() {
278 let tool_ids = tools
279 .tools
280 .into_iter()
281 .map(|tool| {
282 log::info!(
283 "registering context server tool: {:?}",
284 tool.name
285 );
286 tool_working_set.insert(Arc::new(
287 ContextServerTool::new(
288 context_server_manager.clone(),
289 server.id(),
290 tool,
291 ),
292 ))
293 })
294 .collect::<Vec<_>>();
295
296 this.update(cx, |this, cx| {
297 this.context_server_tool_ids.insert(server_id, tool_ids);
298 this.load_default_profile(cx);
299 })
300 .log_err();
301 }
302 }
303 }
304 })
305 .detach();
306 }
307 }
308 context_server::manager::Event::ServerStopped { server_id } => {
309 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
310 tool_working_set.remove(&tool_ids);
311 self.load_default_profile(cx);
312 }
313 }
314 }
315 }
316}
317
/// Lightweight summary of a saved thread, used for listing threads without
/// keeping their messages around.
// Field order is left as-is: it determines the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Thread summary copied from the stored [`SerializedThread`].
    pub summary: SharedString,
    /// Last-update timestamp copied from the stored [`SerializedThread`].
    pub updated_at: DateTime<Utc>,
}
324
/// On-disk representation of a thread. It is stored as JSON in the threads database.
///
/// Decode through [`SerializedThread::from_json`] rather than `Deserialize` directly,
/// so that version checking and legacy upgrades are applied.
// Field order is left as-is: it determines the serialized field order.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Schema version string. `from_json` only accepts [`SerializedThread::VERSION`].
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    // The `serde(default)` fields below fall back to their `Default` value when
    // the key is missing from the stored JSON.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
}
338
339impl SerializedThread {
340 pub const VERSION: &'static str = "0.1.0";
341
342 pub fn from_json(json: &[u8]) -> Result<Self> {
343 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
344 match saved_thread_json.get("version") {
345 Some(serde_json::Value::String(version)) => match version.as_str() {
346 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
347 saved_thread_json,
348 )?),
349 _ => Err(anyhow!(
350 "unrecognized serialized thread version: {}",
351 version
352 )),
353 },
354 None => {
355 let saved_thread =
356 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
357 Ok(saved_thread.upgrade())
358 }
359 version => Err(anyhow!(
360 "unrecognized serialized thread version: {:?}",
361 version
362 )),
363 }
364 }
365}
366
/// A single message within a [`SerializedThread`].
// Field order is left as-is: it determines the serialized field order.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Message content as ordered segments. Empty if the key is missing.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    /// Tool calls made in this message. Empty if the key is missing.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results carried by this message. Empty if the key is missing.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
378
379#[derive(Debug, Serialize, Deserialize)]
380#[serde(tag = "type")]
381pub enum SerializedMessageSegment {
382 #[serde(rename = "text")]
383 Text { text: String },
384 #[serde(rename = "thinking")]
385 Thinking { text: String },
386}
387
/// A tool invocation recorded on a [`SerializedMessage`].
// Field order is left as-is: it determines the serialized field order.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// The tool's input, kept as an arbitrary JSON value.
    pub input: serde_json::Value,
}
394
/// The result of a tool invocation recorded on a [`SerializedMessage`].
// Field order is left as-is: it determines the serialized field order.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// Presumably matches the `id` of the corresponding [`SerializedToolUse`];
    /// both share the same id type. Confirm against the writer.
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
}
401
/// Thread format from before the `version` field existed.
///
/// `SerializedThread::from_json` parses payloads without a `version` key as this
/// type and converts them with `upgrade`.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    /// Falls back to `None` when the key is missing.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
410
411impl LegacySerializedThread {
412 pub fn upgrade(self) -> SerializedThread {
413 SerializedThread {
414 version: SerializedThread::VERSION.to_string(),
415 summary: self.summary,
416 updated_at: self.updated_at,
417 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
418 initial_project_snapshot: self.initial_project_snapshot,
419 cumulative_token_usage: TokenUsage::default(),
420 detailed_summary_state: DetailedSummaryState::default(),
421 }
422 }
423}
424
/// Message format used by [`LegacySerializedThread`]: the whole content is a single
/// `text` string rather than a list of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    /// Empty if the key is missing.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Empty if the key is missing.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
435
436impl LegacySerializedMessage {
437 fn upgrade(self) -> SerializedMessage {
438 SerializedMessage {
439 id: self.id,
440 role: self.role,
441 segments: vec![SerializedMessageSegment::Text { text: self.text }],
442 tool_uses: self.tool_uses,
443 tool_results: self.tool_results,
444 }
445 }
446}
447
/// gpui global holding the (possibly still pending) result of opening the threads
/// database.
///
/// The future is `Shared`, so any number of callers can await the same open. The
/// error is wrapped in `Arc` because a `Shared` future's output must be `Clone`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);
451
// Marker impl that lets `GlobalThreadsDatabase` be stored via `cx.set_global`.
impl Global for GlobalThreadsDatabase {}
453
/// LMDB-backed (via `heed`) store of serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database transaction is run.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Thread table. Keys are bincode-encoded [`ThreadId`]s. Values are JSON-encoded
    /// [`SerializedThread`]s (see the `BytesEncode`/`BytesDecode` impls).
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
459
460impl heed::BytesEncode<'_> for SerializedThread {
461 type EItem = SerializedThread;
462
463 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
464 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
465 }
466}
467
468impl<'a> heed::BytesDecode<'a> for SerializedThread {
469 type DItem = SerializedThread;
470
471 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
472 // We implement this type manually because we want to call `SerializedThread::from_json`,
473 // instead of the Deserialize trait implementation for `SerializedThread`.
474 SerializedThread::from_json(bytes).map_err(Into::into)
475 }
476}
477
478impl ThreadsDatabase {
479 fn global_future(
480 cx: &mut App,
481 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
482 GlobalThreadsDatabase::global(cx).0.clone()
483 }
484
485 fn init(cx: &mut App) {
486 let executor = cx.background_executor().clone();
487 let database_future = executor
488 .spawn({
489 let executor = executor.clone();
490 let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
491 async move { ThreadsDatabase::new(database_path, executor) }
492 })
493 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
494 .boxed()
495 .shared();
496
497 cx.set_global(GlobalThreadsDatabase(database_future));
498 }
499
500 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
501 std::fs::create_dir_all(&path)?;
502
503 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
504 let env = unsafe {
505 heed::EnvOpenOptions::new()
506 .map_size(ONE_GB_IN_BYTES)
507 .max_dbs(1)
508 .open(path)?
509 };
510
511 let mut txn = env.write_txn()?;
512 let threads = env.create_database(&mut txn, Some("threads"))?;
513 txn.commit()?;
514
515 Ok(Self {
516 executor,
517 env,
518 threads,
519 })
520 }
521
522 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
523 let env = self.env.clone();
524 let threads = self.threads;
525
526 self.executor.spawn(async move {
527 let txn = env.read_txn()?;
528 let mut iter = threads.iter(&txn)?;
529 let mut threads = Vec::new();
530 while let Some((key, value)) = iter.next().transpose()? {
531 threads.push(SerializedThreadMetadata {
532 id: key,
533 summary: value.summary,
534 updated_at: value.updated_at,
535 });
536 }
537
538 Ok(threads)
539 })
540 }
541
542 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
543 let env = self.env.clone();
544 let threads = self.threads;
545
546 self.executor.spawn(async move {
547 let txn = env.read_txn()?;
548 let thread = threads.get(&txn, &id)?;
549 Ok(thread)
550 })
551 }
552
553 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
554 let env = self.env.clone();
555 let threads = self.threads;
556
557 self.executor.spawn(async move {
558 let mut txn = env.write_txn()?;
559 threads.put(&mut txn, &id, &thread)?;
560 txn.commit()?;
561 Ok(())
562 })
563 }
564
565 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
566 let env = self.env.clone();
567 let threads = self.threads;
568
569 self.executor.spawn(async move {
570 let mut txn = env.write_txn()?;
571 threads.delete(&mut txn, &id)?;
572 txn.commit()?;
573 Ok(())
574 })
575 }
576}