thread_store.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{ToolId, ToolWorkingSet};
  5use collections::HashMap;
  6use context_server::manager::ContextServerManager;
  7use context_server::{ContextServerFactoryRegistry, ContextServerTool};
  8use gpui::{prelude::*, AppContext, Model, ModelContext, Task};
  9use project::Project;
 10use unindent::Unindent;
 11use util::ResultExt as _;
 12
 13use crate::thread::{Thread, ThreadId};
 14
 15pub struct ThreadStore {
 16    #[allow(unused)]
 17    project: Model<Project>,
 18    tools: Arc<ToolWorkingSet>,
 19    context_server_manager: Model<ContextServerManager>,
 20    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
 21    threads: Vec<Model<Thread>>,
 22}
 23
 24impl ThreadStore {
 25    pub fn new(
 26        project: Model<Project>,
 27        tools: Arc<ToolWorkingSet>,
 28        cx: &mut AppContext,
 29    ) -> Task<Result<Model<Self>>> {
 30        cx.spawn(|mut cx| async move {
 31            let this = cx.new_model(|cx: &mut ModelContext<Self>| {
 32                let context_server_factory_registry =
 33                    ContextServerFactoryRegistry::default_global(cx);
 34                let context_server_manager = cx.new_model(|cx| {
 35                    ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 36                });
 37
 38                let mut this = Self {
 39                    project,
 40                    tools,
 41                    context_server_manager,
 42                    context_server_tool_ids: HashMap::default(),
 43                    threads: Vec::new(),
 44                };
 45                this.mock_recent_threads(cx);
 46                this.register_context_server_handlers(cx);
 47
 48                this
 49            })?;
 50
 51            Ok(this)
 52        })
 53    }
 54
 55    pub fn threads(&self, cx: &ModelContext<Self>) -> Vec<Model<Thread>> {
 56        let mut threads = self
 57            .threads
 58            .iter()
 59            .filter(|thread| !thread.read(cx).is_empty())
 60            .cloned()
 61            .collect::<Vec<_>>();
 62        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.read(cx).updated_at()));
 63        threads
 64    }
 65
 66    pub fn recent_threads(&self, limit: usize, cx: &ModelContext<Self>) -> Vec<Model<Thread>> {
 67        self.threads(cx).into_iter().take(limit).collect()
 68    }
 69
 70    pub fn create_thread(&mut self, cx: &mut ModelContext<Self>) -> Model<Thread> {
 71        let thread = cx.new_model(|cx| Thread::new(self.tools.clone(), cx));
 72        self.threads.push(thread.clone());
 73        thread
 74    }
 75
 76    pub fn open_thread(&self, id: &ThreadId, cx: &mut ModelContext<Self>) -> Option<Model<Thread>> {
 77        self.threads
 78            .iter()
 79            .find(|thread| thread.read(cx).id() == id)
 80            .cloned()
 81    }
 82
 83    fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
 84        cx.subscribe(
 85            &self.context_server_manager.clone(),
 86            Self::handle_context_server_event,
 87        )
 88        .detach();
 89    }
 90
 91    fn handle_context_server_event(
 92        &mut self,
 93        context_server_manager: Model<ContextServerManager>,
 94        event: &context_server::manager::Event,
 95        cx: &mut ModelContext<Self>,
 96    ) {
 97        let tool_working_set = self.tools.clone();
 98        match event {
 99            context_server::manager::Event::ServerStarted { server_id } => {
100                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
101                    let context_server_manager = context_server_manager.clone();
102                    cx.spawn({
103                        let server = server.clone();
104                        let server_id = server_id.clone();
105                        |this, mut cx| async move {
106                            let Some(protocol) = server.client() else {
107                                return;
108                            };
109
110                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
111                                if let Some(tools) = protocol.list_tools().await.log_err() {
112                                    let tool_ids = tools
113                                        .tools
114                                        .into_iter()
115                                        .map(|tool| {
116                                            log::info!(
117                                                "registering context server tool: {:?}",
118                                                tool.name
119                                            );
120                                            tool_working_set.insert(Arc::new(
121                                                ContextServerTool::new(
122                                                    context_server_manager.clone(),
123                                                    server.id(),
124                                                    tool,
125                                                ),
126                                            ))
127                                        })
128                                        .collect::<Vec<_>>();
129
130                                    this.update(&mut cx, |this, _cx| {
131                                        this.context_server_tool_ids.insert(server_id, tool_ids);
132                                    })
133                                    .log_err();
134                                }
135                            }
136                        }
137                    })
138                    .detach();
139                }
140            }
141            context_server::manager::Event::ServerStopped { server_id } => {
142                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
143                    tool_working_set.remove(&tool_ids);
144                }
145            }
146        }
147    }
148}
149
impl ThreadStore {
    /// Creates some mocked recent threads for testing purposes.
    ///
    /// Pushes two hard-coded, already-summarized conversations onto `self.threads`, each
    /// built with `Thread::new` over this store's shared `tools`.
    ///
    /// NOTE(review): despite the "testing purposes" wording, `new` calls this
    /// unconditionally, so every store starts with these demo threads — confirm this is
    /// intended outside of tests.
    fn mock_recent_threads(&mut self, cx: &mut ModelContext<Self>) {
        use language_model::Role;

        // Demo thread 1: a short quantum-computing Q&A (two user/assistant exchanges).
        self.threads.push(cx.new_model(|cx| {
            let mut thread = Thread::new(self.tools.clone(), cx);
            thread.set_summary("Introduction to quantum computing", cx);
            thread.insert_user_message("Hello! Can you help me understand quantum computing?", cx);
            thread.insert_message(Role::Assistant, "Of course! I'd be happy to help you understand quantum computing. Quantum computing is a fascinating field that uses the principles of quantum mechanics to process information. Unlike classical computers that use bits (0s and 1s), quantum computers use quantum bits or 'qubits'. These qubits can exist in multiple states simultaneously, a property called superposition. This allows quantum computers to perform certain calculations much faster than classical computers. What specific aspect of quantum computing would you like to know more about?", cx);
            thread.insert_user_message("That's interesting! Can you explain how quantum entanglement is used in quantum computing?", cx);
            thread.insert_message(Role::Assistant, "Certainly! Quantum entanglement is a key principle used in quantum computing. When two qubits become entangled, the state of one qubit is directly related to the state of the other, regardless of the distance between them. This property is used in quantum computing to create complex quantum states and to perform operations on multiple qubits simultaneously. Entanglement allows quantum computers to process information in ways that classical computers cannot, potentially solving certain problems much more efficiently. For example, it's crucial in quantum error correction and in algorithms like quantum teleportation, which is important for quantum communication.", cx);
            thread
        }));

        // Demo thread 2: Rust web server + async explanation. The assistant replies are
        // multi-line literals containing markdown code fences; `.unindent()` (from the
        // `unindent` crate) strips the common leading indentation these literals pick up
        // from the source layout. The literal text itself is runtime data — edit with care.
        self.threads.push(cx.new_model(|cx| {
            let mut thread = Thread::new(self.tools.clone(), cx);
            thread.set_summary("Rust web development and async programming", cx);
            thread.insert_user_message("Can you show me an example of Rust code for a simple web server?", cx);
            thread.insert_message(Role::Assistant, "Certainly! Here's an example of a simple web server in Rust using the `actix-web` framework:

        ```rust
        use actix_web::{web, App, HttpResponse, HttpServer, Responder};

        async fn hello() -> impl Responder {
            HttpResponse::Ok().body(\"Hello, World!\")
        }

        #[actix_web::main]
        async fn main() -> std::io::Result<()> {
            HttpServer::new(|| {
                App::new()
                    .route(\"/\", web::get().to(hello))
            })
            .bind(\"127.0.0.1:8080\")?
            .run()
            .await
        }
        ```

        This code creates a basic web server that responds with 'Hello, World!' when you access the root URL. Here's a breakdown of what's happening:

        1. We import necessary items from the `actix-web` crate.
        2. We define an async `hello` function that returns a simple HTTP response.
        3. In the `main` function, we set up the server to listen on `127.0.0.1:8080`.
        4. We configure the app to respond to GET requests on the root path with our `hello` function.

        To run this, you'd need to add `actix-web` to your `Cargo.toml` dependencies:

        ```toml
        [dependencies]
        actix-web = \"4.0\"
        ```

        Then you can run the server with `cargo run` and access it at `http://localhost:8080`.".unindent(), cx);
            thread.insert_user_message("That's great! Can you explain more about async functions in Rust?", cx);
            thread.insert_message(Role::Assistant, "Certainly! Async functions are a key feature in Rust for writing efficient, non-blocking code, especially for I/O-bound operations. Here's an overview:

        1. **Syntax**: Async functions are declared using the `async` keyword:

           ```rust
           async fn my_async_function() -> Result<(), Error> {
               // Asynchronous code here
           }
           ```

        2. **Futures**: Async functions return a `Future`. A `Future` represents a value that may not be available yet but will be at some point.

        3. **Await**: Inside an async function, you can use the `.await` syntax to wait for other async operations to complete:

           ```rust
           async fn fetch_data() -> Result<String, Error> {
               let response = make_http_request().await?;
               let data = process_response(response).await?;
               Ok(data)
           }
           ```

        4. **Non-blocking**: Async functions allow the runtime to work on other tasks while waiting for I/O or other operations to complete, making efficient use of system resources.

        5. **Runtime**: To execute async code, you need a runtime like `tokio` or `async-std`. Actix-web, which we used in the previous example, includes its own runtime.

        6. **Error Handling**: Async functions work well with Rust's `?` operator for error handling.

        Async programming in Rust provides a powerful way to write concurrent code that's both safe and efficient. It's particularly useful for servers, network programming, and any application that deals with many concurrent operations.".unindent(), cx);
            thread
        }));
    }
}