1pub mod assistant_panel;
2pub mod assistant_settings;
3mod codegen;
4mod completion_provider;
5mod prompts;
6mod saved_conversation;
7mod search;
8mod slash_command;
9mod streaming_diff;
10
11pub use assistant_panel::AssistantPanel;
12
13use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
14use client::{proto, Client};
15use command_palette_hooks::CommandPaletteFilter;
16pub(crate) use completion_provider::*;
17use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
18pub(crate) use saved_conversation::*;
19use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
20use serde::{Deserialize, Serialize};
21use settings::{Settings, SettingsStore};
22use std::{
23 fmt::{self, Display},
24 sync::Arc,
25};
26use util::paths::EMBEDDINGS_DIR;
27
// All actions dispatched under the `assistant` namespace. This namespace is
// also what `Assistant::set_enabled` shows/hides in the command palette.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand
    ]
);
44
/// Identifier for a message within a conversation.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
49
/// The author of a message in a model conversation.
/// Serialized in lowercase ("user", "assistant", "system"), matching the
/// strings produced by its `Display` impl.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
57
58impl Role {
59 pub fn cycle(&mut self) {
60 *self = match self {
61 Role::User => Role::Assistant,
62 Role::Assistant => Role::System,
63 Role::System => Role::User,
64 }
65 }
66}
67
68impl Display for Role {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
70 match self {
71 Role::User => write!(f, "user"),
72 Role::Assistant => write!(f, "assistant"),
73 Role::System => write!(f, "system"),
74 }
75 }
76}
77
/// A language model the assistant can talk to, keyed by provider.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// Model hosted by zed.dev.
    ZedDotDev(ZedDotDevModel),
    /// Model accessed directly via the OpenAI API.
    OpenAi(OpenAiModel),
    /// Model accessed directly via the Anthropic API.
    Anthropic(AnthropicModel),
}
84
85impl Default for LanguageModel {
86 fn default() -> Self {
87 LanguageModel::ZedDotDev(ZedDotDevModel::default())
88 }
89}
90
91impl LanguageModel {
92 pub fn telemetry_id(&self) -> String {
93 match self {
94 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
95 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
96 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
97 }
98 }
99
100 pub fn display_name(&self) -> String {
101 match self {
102 LanguageModel::OpenAi(model) => model.display_name().into(),
103 LanguageModel::Anthropic(model) => model.display_name().into(),
104 LanguageModel::ZedDotDev(model) => model.display_name().into(),
105 }
106 }
107
108 pub fn max_token_count(&self) -> usize {
109 match self {
110 LanguageModel::OpenAi(model) => model.max_token_count(),
111 LanguageModel::Anthropic(model) => model.max_token_count(),
112 LanguageModel::ZedDotDev(model) => model.max_token_count(),
113 }
114 }
115
116 pub fn id(&self) -> &str {
117 match self {
118 LanguageModel::OpenAi(model) => model.id(),
119 LanguageModel::Anthropic(model) => model.id(),
120 LanguageModel::ZedDotDev(model) => model.id(),
121 }
122 }
123}
124
/// A single message in a completion request sent to a language model.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: String,
}
130
131impl LanguageModelRequestMessage {
132 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
133 proto::LanguageModelRequestMessage {
134 role: match self.role {
135 Role::User => proto::LanguageModelRole::LanguageModelUser,
136 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
137 Role::System => proto::LanguageModelRole::LanguageModelSystem,
138 } as i32,
139 content: self.content.clone(),
140 tool_calls: Vec::new(),
141 tool_call_id: None,
142 }
143 }
144}
145
/// A full completion request: the target model plus conversation messages
/// and sampling parameters.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    pub model: LanguageModel,
    pub messages: Vec<LanguageModelRequestMessage>,
    // Stop sequences that end generation when emitted by the model.
    pub stop: Vec<String>,
    pub temperature: f32,
}
153
154impl LanguageModelRequest {
155 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
156 proto::CompleteWithLanguageModel {
157 model: self.model.id().to_string(),
158 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
159 stop: self.stop.clone(),
160 temperature: self.temperature,
161 tool_choice: None,
162 tools: Vec::new(),
163 }
164 }
165}
166
/// A (possibly partial) message received from a language model.
/// Both fields are optional because streaming deltas may carry only one.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
172
/// Token accounting reported by a model provider for one completion.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
179
/// One streamed chunk of a model response.
/// NOTE(review): field names match OpenAI-style streaming "choice" deltas —
/// confirm against the completion provider that deserializes this.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    pub index: u32,
    pub delta: LanguageModelResponseMessage,
    pub finish_reason: Option<String>,
}
186
/// Per-message bookkeeping stored alongside a conversation message.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    role: Role,
    status: MessageStatus,
}
192
/// Lifecycle state of a message's model response.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    // Response still in flight.
    Pending,
    // Response completed successfully.
    Done,
    // Response failed; carries the error text.
    Error(SharedString),
}
199
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

// Registered as an app-wide global in `init`.
impl Global for Assistant {}
208
209impl Assistant {
210 const NAMESPACE: &'static str = "assistant";
211
212 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
213 if self.enabled == enabled {
214 return;
215 }
216
217 self.enabled = enabled;
218
219 if !enabled {
220 CommandPaletteFilter::update_global(cx, |filter, _cx| {
221 filter.hide_namespace(Self::NAMESPACE);
222 });
223
224 return;
225 }
226
227 CommandPaletteFilter::update_global(cx, |filter, _cx| {
228 filter.show_namespace(Self::NAMESPACE);
229 });
230 }
231}
232
/// Initializes the assistant crate: registers settings, kicks off the
/// semantic index in the background, initializes the completion provider,
/// slash commands, and panel, and keeps the `Assistant` global's enabled
/// flag in sync with `AssistantSettings`.
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Build the semantic index asynchronously; on success it is stored as a
    // global. Errors from `SemanticIndex::new` propagate out of the task via
    // `?` and are dropped by `detach()`.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();
    completion_provider::init(client, cx);
    assistant_slash_command::init(cx);
    assistant_panel::init(cx);

    // Start with assistant commands hidden; `set_enabled` below reveals them
    // if settings say the assistant is enabled.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // Re-apply the enabled flag whenever the settings store changes.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);

            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
272
/// Test-only constructor: enables `env_logger` when `RUST_LOG` is set,
/// so tests can opt into log output.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}