1pub mod assistant_panel;
2pub mod assistant_settings;
3mod completion_provider;
4mod context_store;
5mod inline_assistant;
6mod model_selector;
7mod prompt_library;
8mod prompts;
9mod search;
10mod slash_command;
11mod streaming_diff;
12
13pub use assistant_panel::AssistantPanel;
14
15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
16use assistant_slash_command::SlashCommandRegistry;
17use client::{proto, Client};
18use command_palette_hooks::CommandPaletteFilter;
19pub(crate) use completion_provider::*;
20pub(crate) use context_store::*;
21use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
22pub(crate) use inline_assistant::*;
23pub(crate) use model_selector::*;
24use rustdoc::RustdocStore;
25use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
26use serde::{Deserialize, Serialize};
27use settings::{Settings, SettingsStore};
28use slash_command::{
29 active_command, default_command, diagnostics_command, fetch_command, file_command, now_command,
30 project_command, prompt_command, rustdoc_command, search_command, tabs_command,
31};
32use std::{
33 fmt::{self, Display},
34 sync::Arc,
35};
36pub(crate) use streaming_diff::*;
37
// Actions dispatched under the `assistant` namespace (matching
// `Assistant::NAMESPACE` below), bindable from keymaps and filtered in the
// command palette when the assistant is disabled.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
55
/// Opaque identifier for a message within an assistant conversation.
// NOTE(review): presumably assigned by the context store — confirm there.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
60
/// The author of a chat message.
///
/// Serialized in lowercase (`"user"`, `"assistant"`, `"system"`); the
/// `Display` impl below produces the same lowercase strings.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
68
69impl Role {
70 pub fn cycle(&mut self) {
71 *self = match self {
72 Role::User => Role::Assistant,
73 Role::Assistant => Role::System,
74 Role::System => Role::User,
75 }
76 }
77}
78
79impl Display for Role {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
81 match self {
82 Role::User => write!(f, "user"),
83 Role::Assistant => write!(f, "assistant"),
84 Role::System => write!(f, "system"),
85 }
86 }
87}
88
/// A configured language model, tagged by the provider that serves it.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// Model hosted by Zed's cloud service (`zed.dev`).
    Cloud(CloudModel),
    OpenAi(OpenAiModel),
    Anthropic(AnthropicModel),
    /// Locally served model via Ollama.
    Ollama(OllamaModel),
}
96
97impl Default for LanguageModel {
98 fn default() -> Self {
99 LanguageModel::Cloud(CloudModel::default())
100 }
101}
102
103impl LanguageModel {
104 pub fn telemetry_id(&self) -> String {
105 match self {
106 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
107 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
108 LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
109 LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
110 }
111 }
112
113 pub fn display_name(&self) -> String {
114 match self {
115 LanguageModel::OpenAi(model) => model.display_name().into(),
116 LanguageModel::Anthropic(model) => model.display_name().into(),
117 LanguageModel::Cloud(model) => model.display_name().into(),
118 LanguageModel::Ollama(model) => model.display_name().into(),
119 }
120 }
121
122 pub fn max_token_count(&self) -> usize {
123 match self {
124 LanguageModel::OpenAi(model) => model.max_token_count(),
125 LanguageModel::Anthropic(model) => model.max_token_count(),
126 LanguageModel::Cloud(model) => model.max_token_count(),
127 LanguageModel::Ollama(model) => model.max_token_count(),
128 }
129 }
130
131 pub fn id(&self) -> &str {
132 match self {
133 LanguageModel::OpenAi(model) => model.id(),
134 LanguageModel::Anthropic(model) => model.id(),
135 LanguageModel::Cloud(model) => model.id(),
136 LanguageModel::Ollama(model) => model.id(),
137 }
138 }
139}
140
/// A single chat message within a completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    /// Who authored the message.
    pub role: Role,
    /// The message text.
    pub content: String,
}
146
147impl LanguageModelRequestMessage {
148 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
149 proto::LanguageModelRequestMessage {
150 role: match self.role {
151 Role::User => proto::LanguageModelRole::LanguageModelUser,
152 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
153 Role::System => proto::LanguageModelRole::LanguageModelSystem,
154 } as i32,
155 content: self.content.clone(),
156 tool_calls: Vec::new(),
157 tool_call_id: None,
158 }
159 }
160}
161
/// A completion request to be sent to a language model provider.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    /// Which model should service the request.
    pub model: LanguageModel,
    /// The messages making up the conversation so far.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Stop sequences that end generation early.
    pub stop: Vec<String>,
    /// Sampling temperature.
    pub temperature: f32,
}
169
170impl LanguageModelRequest {
171 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
172 proto::CompleteWithLanguageModel {
173 model: self.model.id().to_string(),
174 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
175 stop: self.stop.clone(),
176 temperature: self.temperature,
177 tool_choice: None,
178 tools: Vec::new(),
179 }
180 }
181
182 /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
183 pub fn preprocess(&mut self) {
184 match &self.model {
185 LanguageModel::OpenAi(_) => {}
186 LanguageModel::Anthropic(_) => {}
187 LanguageModel::Ollama(_) => {}
188 LanguageModel::Cloud(model) => match model {
189 CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
190 preprocess_anthropic_request(self);
191 }
192 _ => {}
193 },
194 }
195 }
196}
197
/// A message (or partial message) received from a model.
///
/// Both fields are optional: streamed deltas may carry only a role, only
/// content, or neither (see `LanguageModelChoiceDelta::delta`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
203
/// Token accounting reported by a completion API.
// NOTE(review): field names mirror the OpenAI usage object — confirm other
// providers deserialize into the same shape.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Tokens consumed by the input/prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the response.
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens.
    pub total_tokens: u32,
}
210
/// One incremental "choice" chunk from a streamed completion response.
// NOTE(review): mirrors the OpenAI streaming schema (`choices[i].delta`) —
// confirm all providers map onto it.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The partial role/content carried by this chunk.
    pub delta: LanguageModelResponseMessage,
    /// Reason generation finished, if reported by the provider.
    pub finish_reason: Option<String>,
}
217
/// Per-message bookkeeping serialized alongside the message content.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    /// Author of the message.
    role: Role,
    /// Current lifecycle status of the message.
    status: MessageStatus,
}
223
/// Lifecycle status of a message.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// Response not yet complete.
    Pending,
    /// Completed successfully.
    Done,
    /// Failed; carries the error text.
    Error(SharedString),
}
230
/// The state pertaining to the Assistant.
///
/// Stored as a gpui global (see `init`); `enabled` mirrors the
/// `AssistantSettings::enabled` setting.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

// Marker impl allowing `Assistant` to be stored in gpui's global state map.
impl Global for Assistant {}
239
240impl Assistant {
241 const NAMESPACE: &'static str = "assistant";
242
243 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
244 if self.enabled == enabled {
245 return;
246 }
247
248 self.enabled = enabled;
249
250 if !enabled {
251 CommandPaletteFilter::update_global(cx, |filter, _cx| {
252 filter.hide_namespace(Self::NAMESPACE);
253 });
254
255 return;
256 }
257
258 CommandPaletteFilter::update_global(cx, |filter, _cx| {
259 filter.show_namespace(Self::NAMESPACE);
260 });
261 }
262}
263
/// Initializes the assistant subsystem: registers global state and settings,
/// kicks off the semantic index, wires up sub-modules, and keeps the
/// command-palette `assistant` namespace in sync with the `enabled` setting.
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Build the semantic index in the background; embeddings are produced by
    // the cloud provider backed by `client`.
    // NOTE(review): `detach()` discards any setup error — consider logging it.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    prompt_library::init(cx);
    completion_provider::init(client.clone(), cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);
    inline_assistant::init(client.telemetry().clone(), cx);
    RustdocStore::init_global(cx);

    // Hide the assistant's actions up front; `set_enabled` below re-shows
    // them if the setting says the assistant is on.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // Re-apply the enabled flag whenever settings change.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
307
308fn register_slash_commands(cx: &mut AppContext) {
309 let slash_command_registry = SlashCommandRegistry::global(cx);
310 slash_command_registry.register_command(file_command::FileSlashCommand, true);
311 slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
312 slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
313 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
314 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
315 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
316 slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
317 slash_command_registry.register_command(now_command::NowSlashCommand, true);
318 slash_command_registry.register_command(diagnostics_command::DiagnosticsCommand, true);
319 slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
320 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
321}
322
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Test-only: opt into env_logger output by setting RUST_LOG.
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}