1pub mod assistant_panel;
2pub mod assistant_settings;
3mod codegen;
4mod completion_provider;
5mod prompts;
6mod saved_conversation;
7mod streaming_diff;
8
9mod embedded_scope;
10
11pub use assistant_panel::AssistantPanel;
12use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
13use chrono::{DateTime, Local};
14use client::{proto, Client};
15use command_palette_hooks::CommandPaletteFilter;
16pub(crate) use completion_provider::*;
17use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
18pub(crate) use saved_conversation::*;
19use serde::{Deserialize, Serialize};
20use settings::{Settings, SettingsStore};
21use std::{
22 fmt::{self, Display},
23 sync::Arc,
24};
25
26actions!(
27 assistant,
28 [
29 NewConversation,
30 Assist,
31 Split,
32 CycleMessageRole,
33 QuoteSelection,
34 ToggleFocus,
35 ResetKey,
36 InlineAssist,
37 ToggleIncludeConversation,
38 ]
39);
40
41#[derive(
42 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
43)]
44struct MessageId(usize);
45
46#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
47#[serde(rename_all = "lowercase")]
48pub enum Role {
49 User,
50 Assistant,
51 System,
52}
53
54impl Role {
55 pub fn cycle(&mut self) {
56 *self = match self {
57 Role::User => Role::Assistant,
58 Role::Assistant => Role::System,
59 Role::System => Role::User,
60 }
61 }
62}
63
64impl Display for Role {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
66 match self {
67 Role::User => write!(f, "user"),
68 Role::Assistant => write!(f, "assistant"),
69 Role::System => write!(f, "system"),
70 }
71 }
72}
73
74#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
75pub enum LanguageModel {
76 ZedDotDev(ZedDotDevModel),
77 OpenAi(OpenAiModel),
78}
79
80impl Default for LanguageModel {
81 fn default() -> Self {
82 LanguageModel::ZedDotDev(ZedDotDevModel::default())
83 }
84}
85
86impl LanguageModel {
87 pub fn telemetry_id(&self) -> String {
88 match self {
89 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
90 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
91 }
92 }
93
94 pub fn display_name(&self) -> String {
95 match self {
96 LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
97 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
98 }
99 }
100
101 pub fn max_token_count(&self) -> usize {
102 match self {
103 LanguageModel::OpenAi(model) => model.max_token_count(),
104 LanguageModel::ZedDotDev(model) => model.max_token_count(),
105 }
106 }
107
108 pub fn id(&self) -> &str {
109 match self {
110 LanguageModel::OpenAi(model) => model.id(),
111 LanguageModel::ZedDotDev(model) => model.id(),
112 }
113 }
114}
115
116#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
117pub struct LanguageModelRequestMessage {
118 pub role: Role,
119 pub content: String,
120}
121
122impl LanguageModelRequestMessage {
123 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
124 proto::LanguageModelRequestMessage {
125 role: match self.role {
126 Role::User => proto::LanguageModelRole::LanguageModelUser,
127 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
128 Role::System => proto::LanguageModelRole::LanguageModelSystem,
129 } as i32,
130 content: self.content.clone(),
131 tool_calls: Vec::new(),
132 tool_call_id: None,
133 }
134 }
135}
136
137#[derive(Debug, Default, Serialize)]
138pub struct LanguageModelRequest {
139 pub model: LanguageModel,
140 pub messages: Vec<LanguageModelRequestMessage>,
141 pub stop: Vec<String>,
142 pub temperature: f32,
143}
144
145impl LanguageModelRequest {
146 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
147 proto::CompleteWithLanguageModel {
148 model: self.model.id().to_string(),
149 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
150 stop: self.stop.clone(),
151 temperature: self.temperature,
152 tool_choice: None,
153 tools: Vec::new(),
154 }
155 }
156}
157
158#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
159pub struct LanguageModelResponseMessage {
160 pub role: Option<Role>,
161 pub content: Option<String>,
162}
163
164#[derive(Deserialize, Debug)]
165pub struct LanguageModelUsage {
166 pub prompt_tokens: u32,
167 pub completion_tokens: u32,
168 pub total_tokens: u32,
169}
170
171#[derive(Deserialize, Debug)]
172pub struct LanguageModelChoiceDelta {
173 pub index: u32,
174 pub delta: LanguageModelResponseMessage,
175 pub finish_reason: Option<String>,
176}
177
178#[derive(Clone, Debug, Serialize, Deserialize)]
179struct MessageMetadata {
180 role: Role,
181 sent_at: DateTime<Local>,
182 status: MessageStatus,
183}
184
185#[derive(Clone, Debug, Serialize, Deserialize)]
186enum MessageStatus {
187 Pending,
188 Done,
189 Error(SharedString),
190}
191
192/// The state pertaining to the Assistant.
193#[derive(Default)]
194struct Assistant {
195 /// Whether the Assistant is enabled.
196 enabled: bool,
197}
198
199impl Global for Assistant {}
200
201impl Assistant {
202 const NAMESPACE: &'static str = "assistant";
203
204 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
205 if self.enabled == enabled {
206 return;
207 }
208
209 self.enabled = enabled;
210
211 if !enabled {
212 CommandPaletteFilter::update_global(cx, |filter, _cx| {
213 filter.hide_namespace(Self::NAMESPACE);
214 });
215
216 return;
217 }
218
219 CommandPaletteFilter::update_global(cx, |filter, _cx| {
220 filter.show_namespace(Self::NAMESPACE);
221 });
222 }
223}
224
225pub fn init(client: Arc<Client>, cx: &mut AppContext) {
226 cx.set_global(Assistant::default());
227 AssistantSettings::register(cx);
228 completion_provider::init(client, cx);
229 assistant_panel::init(cx);
230
231 CommandPaletteFilter::update_global(cx, |filter, _cx| {
232 filter.hide_namespace(Assistant::NAMESPACE);
233 });
234 cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
235 let settings = AssistantSettings::get_global(cx);
236
237 assistant.set_enabled(settings.enabled, cx);
238 });
239 cx.observe_global::<SettingsStore>(|cx| {
240 cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
241 let settings = AssistantSettings::get_global(cx);
242
243 assistant.set_enabled(settings.enabled, cx);
244 });
245 })
246 .detach();
247}
248
/// Test-only constructor: initializes `env_logger`, but only when the
/// `RUST_LOG` environment variable can be read successfully.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}