1package workspace
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "strings"
8 "sync"
9 "time"
10
11 tea "charm.land/bubbletea/v2"
12 "github.com/charmbracelet/crush/internal/agent/notify"
13 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
14 "github.com/charmbracelet/crush/internal/client"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/log"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/oauth"
21 "github.com/charmbracelet/crush/internal/permission"
22 "github.com/charmbracelet/crush/internal/proto"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25)
26
27// ClientWorkspace implements the Workspace interface by delegating all
28// operations to a remote server via the client SDK. It caches the
29// proto.Workspace returned at creation time and refreshes it after
30// config-mutating operations.
31type ClientWorkspace struct {
32 client *client.Client
33
34 mu sync.RWMutex
35 ws proto.Workspace
36}
37
38// NewClientWorkspace creates a new ClientWorkspace that proxies all
39// operations through the given client SDK. The ws parameter is the
40// proto.Workspace snapshot returned by the server at creation time.
41func NewClientWorkspace(c *client.Client, ws proto.Workspace) *ClientWorkspace {
42 if ws.Config != nil {
43 ws.Config.SetupAgents()
44 }
45 return &ClientWorkspace{
46 client: c,
47 ws: ws,
48 }
49}
50
51// refreshWorkspace re-fetches the workspace from the server, updating
52// the cached snapshot. Called after config-mutating operations.
53func (w *ClientWorkspace) refreshWorkspace() {
54 updated, err := w.client.GetWorkspace(context.Background(), w.ws.ID)
55 if err != nil {
56 slog.Error("Failed to refresh workspace", "error", err)
57 return
58 }
59 if updated.Config != nil {
60 updated.Config.SetupAgents()
61 }
62 w.mu.Lock()
63 w.ws = *updated
64 w.mu.Unlock()
65}
66
67// cached returns a snapshot of the cached workspace.
68func (w *ClientWorkspace) cached() proto.Workspace {
69 w.mu.RLock()
70 defer w.mu.RUnlock()
71 return w.ws
72}
73
74// workspaceID returns the cached workspace ID.
75func (w *ClientWorkspace) workspaceID() string {
76 return w.cached().ID
77}
78
79// -- Sessions --
80
81func (w *ClientWorkspace) CreateSession(ctx context.Context, title string) (session.Session, error) {
82 sess, err := w.client.CreateSession(ctx, w.workspaceID(), title)
83 if err != nil {
84 return session.Session{}, err
85 }
86 return *sess, nil
87}
88
89func (w *ClientWorkspace) GetSession(ctx context.Context, sessionID string) (session.Session, error) {
90 sess, err := w.client.GetSession(ctx, w.workspaceID(), sessionID)
91 if err != nil {
92 return session.Session{}, err
93 }
94 return *sess, nil
95}
96
97func (w *ClientWorkspace) ListSessions(ctx context.Context) ([]session.Session, error) {
98 return w.client.ListSessions(ctx, w.workspaceID())
99}
100
101func (w *ClientWorkspace) SaveSession(ctx context.Context, sess session.Session) (session.Session, error) {
102 saved, err := w.client.SaveSession(ctx, w.workspaceID(), sess)
103 if err != nil {
104 return session.Session{}, err
105 }
106 return *saved, nil
107}
108
109func (w *ClientWorkspace) DeleteSession(ctx context.Context, sessionID string) error {
110 return w.client.DeleteSession(ctx, w.workspaceID(), sessionID)
111}
112
113func (w *ClientWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string {
114 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
115}
116
117func (w *ClientWorkspace) ParseAgentToolSessionID(sessionID string) (string, string, bool) {
118 parts := strings.Split(sessionID, "$$")
119 if len(parts) != 2 {
120 return "", "", false
121 }
122 return parts[0], parts[1], true
123}
124
125// -- Messages --
126
127func (w *ClientWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) {
128 return w.client.ListMessages(ctx, w.workspaceID(), sessionID)
129}
130
131func (w *ClientWorkspace) ListUserMessages(ctx context.Context, sessionID string) ([]message.Message, error) {
132 return w.client.ListUserMessages(ctx, w.workspaceID(), sessionID)
133}
134
135func (w *ClientWorkspace) ListAllUserMessages(ctx context.Context) ([]message.Message, error) {
136 return w.client.ListAllUserMessages(ctx, w.workspaceID())
137}
138
139// -- Agent --
140
141func (w *ClientWorkspace) AgentRun(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) error {
142 return w.client.SendMessage(ctx, w.workspaceID(), sessionID, prompt, attachments...)
143}
144
145func (w *ClientWorkspace) AgentCancel(sessionID string) {
146 _ = w.client.CancelAgentSession(context.Background(), w.workspaceID(), sessionID)
147}
148
149func (w *ClientWorkspace) AgentIsBusy() bool {
150 info, err := w.client.GetAgentInfo(context.Background(), w.workspaceID())
151 if err != nil {
152 return false
153 }
154 return info.IsBusy
155}
156
157func (w *ClientWorkspace) AgentIsSessionBusy(sessionID string) bool {
158 info, err := w.client.GetAgentSessionInfo(context.Background(), w.workspaceID(), sessionID)
159 if err != nil {
160 return false
161 }
162 return info.IsBusy
163}
164
165func (w *ClientWorkspace) AgentModel() AgentModel {
166 info, err := w.client.GetAgentInfo(context.Background(), w.workspaceID())
167 if err != nil {
168 return AgentModel{}
169 }
170 return AgentModel{
171 CatwalkCfg: info.Model,
172 ModelCfg: info.ModelCfg,
173 }
174}
175
176func (w *ClientWorkspace) AgentIsReady() bool {
177 info, err := w.client.GetAgentInfo(context.Background(), w.workspaceID())
178 if err != nil {
179 return false
180 }
181 return info.IsReady
182}
183
184func (w *ClientWorkspace) AgentQueuedPrompts(sessionID string) int {
185 count, err := w.client.GetAgentSessionQueuedPrompts(context.Background(), w.workspaceID(), sessionID)
186 if err != nil {
187 return 0
188 }
189 return count
190}
191
192func (w *ClientWorkspace) AgentQueuedPromptsList(sessionID string) []string {
193 prompts, err := w.client.GetAgentSessionQueuedPromptsList(context.Background(), w.workspaceID(), sessionID)
194 if err != nil {
195 return nil
196 }
197 return prompts
198}
199
200func (w *ClientWorkspace) AgentClearQueue(sessionID string) {
201 _ = w.client.ClearAgentSessionQueuedPrompts(context.Background(), w.workspaceID(), sessionID)
202}
203
204func (w *ClientWorkspace) AgentSummarize(ctx context.Context, sessionID string) error {
205 return w.client.AgentSummarizeSession(ctx, w.workspaceID(), sessionID)
206}
207
208func (w *ClientWorkspace) UpdateAgentModel(ctx context.Context) error {
209 return w.client.UpdateAgent(ctx, w.workspaceID())
210}
211
212func (w *ClientWorkspace) InitCoderAgent(ctx context.Context) error {
213 return w.client.InitiateAgentProcessing(ctx, w.workspaceID())
214}
215
216func (w *ClientWorkspace) GetDefaultSmallModel(providerID string) config.SelectedModel {
217 model, err := w.client.GetDefaultSmallModel(context.Background(), w.workspaceID(), providerID)
218 if err != nil {
219 return config.SelectedModel{}
220 }
221 return *model
222}
223
224// -- Permissions --
225
226func (w *ClientWorkspace) PermissionGrant(perm permission.PermissionRequest) {
227 _ = w.client.GrantPermission(context.Background(), w.workspaceID(), proto.PermissionGrant{
228 Permission: proto.PermissionRequest{
229 ID: perm.ID,
230 SessionID: perm.SessionID,
231 ToolCallID: perm.ToolCallID,
232 ToolName: perm.ToolName,
233 Description: perm.Description,
234 Action: perm.Action,
235 Path: perm.Path,
236 Params: perm.Params,
237 },
238 Action: proto.PermissionAllowForSession,
239 })
240}
241
242func (w *ClientWorkspace) PermissionGrantPersistent(perm permission.PermissionRequest) {
243 _ = w.client.GrantPermission(context.Background(), w.workspaceID(), proto.PermissionGrant{
244 Permission: proto.PermissionRequest{
245 ID: perm.ID,
246 SessionID: perm.SessionID,
247 ToolCallID: perm.ToolCallID,
248 ToolName: perm.ToolName,
249 Description: perm.Description,
250 Action: perm.Action,
251 Path: perm.Path,
252 Params: perm.Params,
253 },
254 Action: proto.PermissionAllow,
255 })
256}
257
258func (w *ClientWorkspace) PermissionDeny(perm permission.PermissionRequest) {
259 _ = w.client.GrantPermission(context.Background(), w.workspaceID(), proto.PermissionGrant{
260 Permission: proto.PermissionRequest{
261 ID: perm.ID,
262 SessionID: perm.SessionID,
263 ToolCallID: perm.ToolCallID,
264 ToolName: perm.ToolName,
265 Description: perm.Description,
266 Action: perm.Action,
267 Path: perm.Path,
268 Params: perm.Params,
269 },
270 Action: proto.PermissionDeny,
271 })
272}
273
274func (w *ClientWorkspace) PermissionSkipRequests() bool {
275 skip, err := w.client.GetPermissionsSkipRequests(context.Background(), w.workspaceID())
276 if err != nil {
277 return false
278 }
279 return skip
280}
281
282func (w *ClientWorkspace) PermissionSetSkipRequests(skip bool) {
283 _ = w.client.SetPermissionsSkipRequests(context.Background(), w.workspaceID(), skip)
284}
285
286// -- FileTracker --
287
288func (w *ClientWorkspace) FileTrackerRecordRead(ctx context.Context, sessionID, path string) {
289 _ = w.client.FileTrackerRecordRead(ctx, w.workspaceID(), sessionID, path)
290}
291
292func (w *ClientWorkspace) FileTrackerLastReadTime(ctx context.Context, sessionID, path string) time.Time {
293 t, err := w.client.FileTrackerLastReadTime(ctx, w.workspaceID(), sessionID, path)
294 if err != nil {
295 return time.Time{}
296 }
297 return t
298}
299
300func (w *ClientWorkspace) FileTrackerListReadFiles(ctx context.Context, sessionID string) ([]string, error) {
301 return w.client.FileTrackerListReadFiles(ctx, w.workspaceID(), sessionID)
302}
303
304// -- History --
305
306func (w *ClientWorkspace) ListSessionHistory(ctx context.Context, sessionID string) ([]history.File, error) {
307 return w.client.ListSessionHistoryFiles(ctx, w.workspaceID(), sessionID)
308}
309
310// -- LSP --
311
312func (w *ClientWorkspace) LSPStart(ctx context.Context, path string) {
313 _ = w.client.LSPStart(ctx, w.workspaceID(), path)
314}
315
316func (w *ClientWorkspace) LSPStopAll(ctx context.Context) {
317 _ = w.client.LSPStopAll(ctx, w.workspaceID())
318}
319
320func (w *ClientWorkspace) LSPGetStates() map[string]LSPClientInfo {
321 states, err := w.client.GetLSPs(context.Background(), w.workspaceID())
322 if err != nil {
323 return nil
324 }
325 result := make(map[string]LSPClientInfo, len(states))
326 for k, v := range states {
327 result[k] = LSPClientInfo{
328 Name: v.Name,
329 State: v.State,
330 Error: v.Error,
331 DiagnosticCount: v.DiagnosticCount,
332 ConnectedAt: v.ConnectedAt,
333 }
334 }
335 return result
336}
337
338func (w *ClientWorkspace) LSPGetClient(_ string) (*lsp.Client, bool) {
339 return nil, false
340}
341
342// -- Config (read-only) --
343
344func (w *ClientWorkspace) Config() *config.Config {
345 return w.cached().Config
346}
347
348func (w *ClientWorkspace) WorkingDir() string {
349 return w.cached().Path
350}
351
352func (w *ClientWorkspace) Resolver() config.VariableResolver {
353 return config.IdentityResolver()
354}
355
356// -- Config mutations --
357
358func (w *ClientWorkspace) UpdatePreferredModel(scope config.Scope, modelType config.SelectedModelType, model config.SelectedModel) error {
359 err := w.client.UpdatePreferredModel(context.Background(), w.workspaceID(), scope, modelType, model)
360 if err == nil {
361 w.refreshWorkspace()
362 }
363 return err
364}
365
366func (w *ClientWorkspace) SetCompactMode(scope config.Scope, enabled bool) error {
367 err := w.client.SetCompactMode(context.Background(), w.workspaceID(), scope, enabled)
368 if err == nil {
369 w.refreshWorkspace()
370 }
371 return err
372}
373
374func (w *ClientWorkspace) SetProviderAPIKey(scope config.Scope, providerID string, apiKey any) error {
375 err := w.client.SetProviderAPIKey(context.Background(), w.workspaceID(), scope, providerID, apiKey)
376 if err == nil {
377 w.refreshWorkspace()
378 }
379 return err
380}
381
382func (w *ClientWorkspace) SetConfigField(scope config.Scope, key string, value any) error {
383 err := w.client.SetConfigField(context.Background(), w.workspaceID(), scope, key, value)
384 if err == nil {
385 w.refreshWorkspace()
386 }
387 return err
388}
389
390func (w *ClientWorkspace) RemoveConfigField(scope config.Scope, key string) error {
391 err := w.client.RemoveConfigField(context.Background(), w.workspaceID(), scope, key)
392 if err == nil {
393 w.refreshWorkspace()
394 }
395 return err
396}
397
398func (w *ClientWorkspace) ImportCopilot() (*oauth.Token, bool) {
399 token, ok, err := w.client.ImportCopilot(context.Background(), w.workspaceID())
400 if err != nil {
401 return nil, false
402 }
403 if ok {
404 w.refreshWorkspace()
405 }
406 return token, ok
407}
408
409func (w *ClientWorkspace) RefreshOAuthToken(ctx context.Context, scope config.Scope, providerID string) error {
410 err := w.client.RefreshOAuthToken(ctx, w.workspaceID(), scope, providerID)
411 if err == nil {
412 w.refreshWorkspace()
413 }
414 return err
415}
416
417// -- Project lifecycle --
418
419func (w *ClientWorkspace) ProjectNeedsInitialization() (bool, error) {
420 return w.client.ProjectNeedsInitialization(context.Background(), w.workspaceID())
421}
422
423func (w *ClientWorkspace) MarkProjectInitialized() error {
424 return w.client.MarkProjectInitialized(context.Background(), w.workspaceID())
425}
426
427func (w *ClientWorkspace) InitializePrompt() (string, error) {
428 return w.client.GetInitializePrompt(context.Background(), w.workspaceID())
429}
430
431// -- MCP operations --
432
433func (w *ClientWorkspace) MCPGetStates() map[string]mcp.ClientInfo {
434 states, err := w.client.MCPGetStates(context.Background(), w.workspaceID())
435 if err != nil {
436 return nil
437 }
438 result := make(map[string]mcp.ClientInfo, len(states))
439 for k, v := range states {
440 result[k] = mcp.ClientInfo{
441 Name: v.Name,
442 State: mcp.State(v.State),
443 Error: v.Error,
444 Counts: mcp.Counts{
445 Tools: v.ToolCount,
446 Prompts: v.PromptCount,
447 Resources: v.ResourceCount,
448 },
449 ConnectedAt: v.ConnectedAt,
450 }
451 }
452 return result
453}
454
455func (w *ClientWorkspace) MCPRefreshPrompts(ctx context.Context, name string) {
456 _ = w.client.MCPRefreshPrompts(ctx, w.workspaceID(), name)
457}
458
459func (w *ClientWorkspace) MCPRefreshResources(ctx context.Context, name string) {
460 _ = w.client.MCPRefreshResources(ctx, w.workspaceID(), name)
461}
462
463func (w *ClientWorkspace) RefreshMCPTools(ctx context.Context, name string) {
464 _ = w.client.RefreshMCPTools(ctx, w.workspaceID(), name)
465}
466
467func (w *ClientWorkspace) ReadMCPResource(ctx context.Context, name, uri string) ([]MCPResourceContents, error) {
468 contents, err := w.client.ReadMCPResource(ctx, w.workspaceID(), name, uri)
469 if err != nil {
470 return nil, err
471 }
472 result := make([]MCPResourceContents, len(contents))
473 for i, c := range contents {
474 result[i] = MCPResourceContents{
475 URI: c.URI,
476 MIMEType: c.MIMEType,
477 Text: c.Text,
478 Blob: c.Blob,
479 }
480 }
481 return result, nil
482}
483
484func (w *ClientWorkspace) GetMCPPrompt(clientID, promptID string, args map[string]string) (string, error) {
485 return w.client.GetMCPPrompt(context.Background(), w.workspaceID(), clientID, promptID, args)
486}
487
488// -- Lifecycle --
489
490func (w *ClientWorkspace) Subscribe(program *tea.Program) {
491 defer log.RecoverPanic("ClientWorkspace.Subscribe", func() {
492 slog.Info("TUI subscription panic: attempting graceful shutdown")
493 program.Quit()
494 })
495
496 evc, err := w.client.SubscribeEvents(context.Background(), w.workspaceID())
497 if err != nil {
498 slog.Error("Failed to subscribe to events", "error", err)
499 return
500 }
501
502 for ev := range evc {
503 translated := translateEvent(ev)
504 if translated != nil {
505 program.Send(translated)
506 }
507 }
508}
509
510func (w *ClientWorkspace) Shutdown() {
511 _ = w.client.DeleteWorkspace(context.Background(), w.workspaceID())
512}
513
514// translateEvent converts proto-typed SSE events into the domain types
515// that the TUI's Update() method expects.
516func translateEvent(ev any) tea.Msg {
517 switch e := ev.(type) {
518 case pubsub.Event[proto.LSPEvent]:
519 return pubsub.Event[LSPEvent]{
520 Type: e.Type,
521 Payload: LSPEvent{
522 Type: LSPEventType(e.Payload.Type),
523 Name: e.Payload.Name,
524 State: e.Payload.State,
525 Error: e.Payload.Error,
526 DiagnosticCount: e.Payload.DiagnosticCount,
527 },
528 }
529 case pubsub.Event[proto.MCPEvent]:
530 return pubsub.Event[mcp.Event]{
531 Type: e.Type,
532 Payload: mcp.Event{
533 Type: protoToMCPEventType(e.Payload.Type),
534 Name: e.Payload.Name,
535 State: mcp.State(e.Payload.State),
536 Error: e.Payload.Error,
537 Counts: mcp.Counts{
538 Tools: e.Payload.ToolCount,
539 Prompts: e.Payload.PromptCount,
540 Resources: e.Payload.ResourceCount,
541 },
542 },
543 }
544 case pubsub.Event[proto.PermissionRequest]:
545 return pubsub.Event[permission.PermissionRequest]{
546 Type: e.Type,
547 Payload: permission.PermissionRequest{
548 ID: e.Payload.ID,
549 SessionID: e.Payload.SessionID,
550 ToolCallID: e.Payload.ToolCallID,
551 ToolName: e.Payload.ToolName,
552 Description: e.Payload.Description,
553 Action: e.Payload.Action,
554 Path: e.Payload.Path,
555 Params: e.Payload.Params,
556 },
557 }
558 case pubsub.Event[proto.PermissionNotification]:
559 return pubsub.Event[permission.PermissionNotification]{
560 Type: e.Type,
561 Payload: permission.PermissionNotification{
562 ToolCallID: e.Payload.ToolCallID,
563 Granted: e.Payload.Granted,
564 Denied: e.Payload.Denied,
565 },
566 }
567 case pubsub.Event[proto.Message]:
568 return pubsub.Event[message.Message]{
569 Type: e.Type,
570 Payload: protoToMessage(e.Payload),
571 }
572 case pubsub.Event[proto.Session]:
573 return pubsub.Event[session.Session]{
574 Type: e.Type,
575 Payload: protoToSession(e.Payload),
576 }
577 case pubsub.Event[proto.File]:
578 return pubsub.Event[history.File]{
579 Type: e.Type,
580 Payload: protoToFile(e.Payload),
581 }
582 case pubsub.Event[proto.AgentEvent]:
583 return pubsub.Event[notify.Notification]{
584 Type: e.Type,
585 Payload: notify.Notification{
586 SessionID: e.Payload.SessionID,
587 SessionTitle: e.Payload.SessionTitle,
588 Type: notify.Type(e.Payload.Type),
589 },
590 }
591 default:
592 slog.Warn("Unknown event type in translateEvent", "type", fmt.Sprintf("%T", ev))
593 return nil
594 }
595}
596
597func protoToMCPEventType(t proto.MCPEventType) mcp.EventType {
598 switch t {
599 case proto.MCPEventStateChanged:
600 return mcp.EventStateChanged
601 case proto.MCPEventToolsListChanged:
602 return mcp.EventToolsListChanged
603 case proto.MCPEventPromptsListChanged:
604 return mcp.EventPromptsListChanged
605 case proto.MCPEventResourcesListChanged:
606 return mcp.EventResourcesListChanged
607 default:
608 return mcp.EventStateChanged
609 }
610}
611
612func protoToSession(s proto.Session) session.Session {
613 return session.Session{
614 ID: s.ID,
615 ParentSessionID: s.ParentSessionID,
616 Title: s.Title,
617 SummaryMessageID: s.SummaryMessageID,
618 MessageCount: s.MessageCount,
619 PromptTokens: s.PromptTokens,
620 CompletionTokens: s.CompletionTokens,
621 Cost: s.Cost,
622 CreatedAt: s.CreatedAt,
623 UpdatedAt: s.UpdatedAt,
624 }
625}
626
627func protoToFile(f proto.File) history.File {
628 return history.File{
629 ID: f.ID,
630 SessionID: f.SessionID,
631 Path: f.Path,
632 Content: f.Content,
633 Version: f.Version,
634 CreatedAt: f.CreatedAt,
635 UpdatedAt: f.UpdatedAt,
636 }
637}
638
639func protoToMessage(m proto.Message) message.Message {
640 msg := message.Message{
641 ID: m.ID,
642 SessionID: m.SessionID,
643 Role: message.MessageRole(m.Role),
644 Model: m.Model,
645 Provider: m.Provider,
646 CreatedAt: m.CreatedAt,
647 UpdatedAt: m.UpdatedAt,
648 }
649
650 for _, p := range m.Parts {
651 switch v := p.(type) {
652 case proto.TextContent:
653 msg.Parts = append(msg.Parts, message.TextContent{Text: v.Text})
654 case proto.ReasoningContent:
655 msg.Parts = append(msg.Parts, message.ReasoningContent{
656 Thinking: v.Thinking,
657 Signature: v.Signature,
658 StartedAt: v.StartedAt,
659 FinishedAt: v.FinishedAt,
660 })
661 case proto.ToolCall:
662 msg.Parts = append(msg.Parts, message.ToolCall{
663 ID: v.ID,
664 Name: v.Name,
665 Input: v.Input,
666 Finished: v.Finished,
667 })
668 case proto.ToolResult:
669 msg.Parts = append(msg.Parts, message.ToolResult{
670 ToolCallID: v.ToolCallID,
671 Name: v.Name,
672 Content: v.Content,
673 IsError: v.IsError,
674 })
675 case proto.Finish:
676 msg.Parts = append(msg.Parts, message.Finish{
677 Reason: message.FinishReason(v.Reason),
678 Time: v.Time,
679 Message: v.Message,
680 Details: v.Details,
681 })
682 case proto.ImageURLContent:
683 msg.Parts = append(msg.Parts, message.ImageURLContent{URL: v.URL, Detail: v.Detail})
684 case proto.BinaryContent:
685 msg.Parts = append(msg.Parts, message.BinaryContent{Path: v.Path, MIMEType: v.MIMEType, Data: v.Data})
686 }
687 }
688
689 return msg
690}