1pub mod assistant_panel;
2pub mod assistant_settings;
3mod completion_provider;
4mod context_store;
5mod inline_assistant;
6mod model_selector;
7mod prompt_library;
8mod prompts;
9mod search;
10mod slash_command;
11mod streaming_diff;
12
13pub use assistant_panel::AssistantPanel;
14
15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
16use assistant_slash_command::SlashCommandRegistry;
17use client::{proto, Client};
18use command_palette_hooks::CommandPaletteFilter;
19pub(crate) use completion_provider::*;
20pub(crate) use context_store::*;
21use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
22pub(crate) use inline_assistant::*;
23pub(crate) use model_selector::*;
24use rustdoc::RustdocStore;
25use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
26use serde::{Deserialize, Serialize};
27use settings::{Settings, SettingsStore};
28use slash_command::{
29 active_command, default_command, fetch_command, file_command, now_command, project_command,
30 prompt_command, rustdoc_command, search_command, tabs_command,
31};
32use std::{
33 fmt::{self, Display},
34 sync::Arc,
35};
36pub(crate) use streaming_diff::*;
37use util::paths::EMBEDDINGS_DIR;
38
// Declares the assistant's gpui actions under the `assistant` namespace —
// the same namespace that is hidden/shown in the command palette via
// `CommandPaletteFilter` when the assistant is disabled/enabled.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
56
/// Identifier for a message within a conversation.
// NOTE(review): the inner `usize` looks like a per-context counter —
// confirm against the allocation site before relying on ordering semantics.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
61
/// The author of a conversation message.
///
/// Serializes to lowercase ("user" / "assistant" / "system") per the
/// `rename_all` attribute, which also matches the `Display` impl below.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// The human user.
    User,
    /// The language model.
    Assistant,
    /// System instructions.
    System,
}
69
70impl Role {
71 pub fn cycle(&mut self) {
72 *self = match self {
73 Role::User => Role::Assistant,
74 Role::Assistant => Role::System,
75 Role::System => Role::User,
76 }
77 }
78}
79
80impl Display for Role {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
82 match self {
83 Role::User => write!(f, "user"),
84 Role::Assistant => write!(f, "assistant"),
85 Role::System => write!(f, "system"),
86 }
87 }
88}
89
/// The language-model providers/models the assistant can talk to.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// A Zed-hosted model (reported under "zed.dev" in telemetry).
    Cloud(CloudModel),
    OpenAi(OpenAiModel),
    Anthropic(AnthropicModel),
    Ollama(OllamaModel),
}
97
98impl Default for LanguageModel {
99 fn default() -> Self {
100 LanguageModel::Cloud(CloudModel::default())
101 }
102}
103
impl LanguageModel {
    /// Returns a provider-qualified identifier for telemetry, e.g.
    /// "openai/<id>", "anthropic/<id>", "zed.dev/<id>", "ollama/<id>".
    pub fn telemetry_id(&self) -> String {
        match self {
            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
        }
    }

    /// Returns the human-readable model name for display in the UI.
    pub fn display_name(&self) -> String {
        match self {
            LanguageModel::OpenAi(model) => model.display_name().into(),
            LanguageModel::Anthropic(model) => model.display_name().into(),
            LanguageModel::Cloud(model) => model.display_name().into(),
            LanguageModel::Ollama(model) => model.display_name().into(),
        }
    }

    /// Returns the maximum number of tokens the underlying model supports.
    pub fn max_token_count(&self) -> usize {
        match self {
            LanguageModel::OpenAi(model) => model.max_token_count(),
            LanguageModel::Anthropic(model) => model.max_token_count(),
            LanguageModel::Cloud(model) => model.max_token_count(),
            LanguageModel::Ollama(model) => model.max_token_count(),
        }
    }

    /// Returns the provider-specific model identifier (without the
    /// provider prefix that `telemetry_id` adds).
    pub fn id(&self) -> &str {
        match self {
            LanguageModel::OpenAi(model) => model.id(),
            LanguageModel::Anthropic(model) => model.id(),
            LanguageModel::Cloud(model) => model.id(),
            LanguageModel::Ollama(model) => model.id(),
        }
    }
}
141
/// A single message in a [`LanguageModelRequest`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    /// Who authored the message.
    pub role: Role,
    /// The message text.
    pub content: String,
}
147
148impl LanguageModelRequestMessage {
149 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
150 proto::LanguageModelRequestMessage {
151 role: match self.role {
152 Role::User => proto::LanguageModelRole::LanguageModelUser,
153 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
154 Role::System => proto::LanguageModelRole::LanguageModelSystem,
155 } as i32,
156 content: self.content.clone(),
157 tool_calls: Vec::new(),
158 tool_call_id: None,
159 }
160 }
161}
162
/// A completion request to send to a language model.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    /// Which model should serve the request.
    pub model: LanguageModel,
    /// The messages comprising the conversation so far.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Stop sequences that end generation when emitted.
    pub stop: Vec<String>,
    /// Sampling temperature.
    pub temperature: f32,
}
170
171impl LanguageModelRequest {
172 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
173 proto::CompleteWithLanguageModel {
174 model: self.model.id().to_string(),
175 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
176 stop: self.stop.clone(),
177 temperature: self.temperature,
178 tool_choice: None,
179 tools: Vec::new(),
180 }
181 }
182
183 /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
184 pub fn preprocess(&mut self) {
185 match &self.model {
186 LanguageModel::OpenAi(_) => {}
187 LanguageModel::Anthropic(_) => {}
188 LanguageModel::Ollama(_) => {}
189 LanguageModel::Cloud(model) => match model {
190 CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
191 preprocess_anthropic_request(self);
192 }
193 _ => {}
194 },
195 }
196 }
197}
198
/// A message (or partial message) received from a language model.
/// Both fields are optional so the type can represent streaming deltas
/// (see [`LanguageModelChoiceDelta`]) as well as complete messages.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
204
/// Token usage statistics, deserialized from a provider response.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Total tokens reported by the provider (typically prompt + completion).
    pub total_tokens: u32,
}
211
/// One streamed chunk of a completion response.
// NOTE(review): the field names mirror the OpenAI streaming "choice"
// schema — confirm which providers' responses are parsed into this type.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The incremental role/content carried by this chunk.
    pub delta: LanguageModelResponseMessage,
    /// Reason the stream finished, if it has.
    pub finish_reason: Option<String>,
}
218
/// Per-message bookkeeping stored alongside conversation messages.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    /// Who authored the message.
    role: Role,
    /// The message's current lifecycle state.
    status: MessageStatus,
}
224
/// Lifecycle state of a message in a conversation.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// Still awaiting completion.
    Pending,
    /// Fully completed.
    Done,
    /// Failed, with a human-readable error description.
    Error(SharedString),
}
231
/// The state pertaining to the Assistant.
///
/// Stored as a gpui global (set in `init`) so the enabled state can be
/// read and updated from anywhere in the app.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}
238
// Marker impl allowing `Assistant` to be stored as a gpui global.
impl Global for Assistant {}
240
241impl Assistant {
242 const NAMESPACE: &'static str = "assistant";
243
244 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
245 if self.enabled == enabled {
246 return;
247 }
248
249 self.enabled = enabled;
250
251 if !enabled {
252 CommandPaletteFilter::update_global(cx, |filter, _cx| {
253 filter.hide_namespace(Self::NAMESPACE);
254 });
255
256 return;
257 }
258
259 CommandPaletteFilter::update_global(cx, |filter, _cx| {
260 filter.show_namespace(Self::NAMESPACE);
261 });
262 }
263}
264
/// Initializes the Assistant subsystem: global state, settings, the
/// semantic index, providers, slash commands, panels, and the settings
/// observer that keeps the enabled state in sync.
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Build the semantic index in the background using cloud embeddings.
    // NOTE(review): the task is detached, so an initialization error from
    // `?` appears to be dropped silently — confirm whether it should be logged.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    prompt_library::init(cx);
    completion_provider::init(client.clone(), cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);
    inline_assistant::init(client.telemetry().clone(), cx);
    RustdocStore::init_global(cx);

    // Hide the assistant namespace up front; `set_enabled` below re-shows
    // it if settings say the assistant is enabled.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // Re-apply the enabled state whenever settings change.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
308
309fn register_slash_commands(cx: &mut AppContext) {
310 let slash_command_registry = SlashCommandRegistry::global(cx);
311 slash_command_registry.register_command(file_command::FileSlashCommand, true);
312 slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
313 slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
314 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
315 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
316 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
317 slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
318 slash_command_registry.register_command(now_command::NowSlashCommand, true);
319 slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
320 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
321}
322
/// Test-only constructor (run before the test harness via `ctor`) that
/// enables `env_logger` output when `RUST_LOG` is set.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}