app.go

  1package app
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"maps"
 10	"sync"
 11	"time"
 12
 13	tea "github.com/charmbracelet/bubbletea/v2"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/db"
 16	"github.com/charmbracelet/crush/internal/format"
 17	"github.com/charmbracelet/crush/internal/history"
 18	"github.com/charmbracelet/crush/internal/llm/agent"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/pubsub"
 21
 22	"github.com/charmbracelet/crush/internal/lsp"
 23	"github.com/charmbracelet/crush/internal/message"
 24	"github.com/charmbracelet/crush/internal/permission"
 25	"github.com/charmbracelet/crush/internal/session"
 26)
 27
// App is the application container. It holds the domain services, the coder
// agent, the LSP clients, the goroutines that forward service events to the
// TUI, and the bookkeeping Shutdown needs to stop all of them.
type App struct {
	Sessions    session.Service
	Messages    message.Service
	History     history.Service
	Permissions permission.Service

	// CoderAgent stays nil when New finds no usable agent configuration;
	// it is set by InitCoderAgent.
	CoderAgent agent.Service

	// LSPClients is keyed by client name. Shutdown reads it under clientsMutex.
	LSPClients map[string]*lsp.Client

	clientsMutex sync.RWMutex

	// Shutdown calls every function in watcherCancelFuncs (under
	// cancelFuncsMutex) and then waits on lspWatcherWG.
	// NOTE(review): these are populated outside this file section — confirm
	// which goroutines they track.
	watcherCancelFuncs []context.CancelFunc
	cancelFuncsMutex   sync.Mutex
	lspWatcherWG       sync.WaitGroup

	config *config.Config

	// serviceEventsWG tracks the forwarding goroutines started through
	// setupSubscriber. eventsCtx is their shared context, and events is the
	// channel they write into (drained by Subscribe). tuiWG tracks the
	// Subscribe loop itself.
	serviceEventsWG *sync.WaitGroup
	eventsCtx       context.Context
	events          chan tea.Msg
	tuiWG           *sync.WaitGroup

	// global context and cleanup functions
	// (cleanupFuncs are run in order at the end of Shutdown)
	globalCtx    context.Context
	cleanupFuncs []func()
}
 55
 56func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
 57	q := db.New(conn)
 58	sessions := session.NewService(q)
 59	messages := message.NewService(q)
 60	files := history.NewService(q, conn)
 61
 62	app := &App{
 63		Sessions:    sessions,
 64		Messages:    messages,
 65		History:     files,
 66		Permissions: permission.NewPermissionService(cfg.WorkingDir()),
 67		LSPClients:  make(map[string]*lsp.Client),
 68
 69		globalCtx: ctx,
 70
 71		config: cfg,
 72
 73		events:          make(chan tea.Msg, 100),
 74		serviceEventsWG: &sync.WaitGroup{},
 75		tuiWG:           &sync.WaitGroup{},
 76	}
 77
 78	app.setupEvents()
 79
 80	// Initialize LSP clients in the background
 81	go app.initLSPClients(ctx)
 82
 83	// TODO: remove the concept of agent config most likely
 84	if cfg.IsConfigured() {
 85		if err := app.InitCoderAgent(); err != nil {
 86			return nil, fmt.Errorf("failed to initialize coder agent: %w", err)
 87		}
 88	} else {
 89		slog.Warn("No agent configuration found")
 90	}
 91	return app, nil
 92}
 93
 94// RunNonInteractive handles the execution flow when a prompt is provided via CLI flag.
 95func (a *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool) error {
 96	slog.Info("Running in non-interactive mode")
 97
 98	ctx, cancel := context.WithCancel(ctx)
 99	defer cancel()
100
101	// Start spinner if not in quiet mode
102	var spinner *format.Spinner
103	if !quiet {
104		spinner = format.NewSpinner(ctx, cancel, "Generating")
105		spinner.Start()
106	}
107	// Helper function to stop spinner once
108	stopSpinner := func() {
109		if !quiet && spinner != nil {
110			spinner.Stop()
111			spinner = nil
112		}
113	}
114	defer stopSpinner()
115
116	const maxPromptLengthForTitle = 100
117	titlePrefix := "Non-interactive: "
118	var titleSuffix string
119
120	if len(prompt) > maxPromptLengthForTitle {
121		titleSuffix = prompt[:maxPromptLengthForTitle] + "..."
122	} else {
123		titleSuffix = prompt
124	}
125	title := titlePrefix + titleSuffix
126
127	sess, err := a.Sessions.Create(ctx, title)
128	if err != nil {
129		return fmt.Errorf("failed to create session for non-interactive mode: %w", err)
130	}
131	slog.Info("Created session for non-interactive run", "session_id", sess.ID)
132
133	// Automatically approve all permission requests for this non-interactive session
134	a.Permissions.AutoApproveSession(sess.ID)
135
136	done, err := a.CoderAgent.Run(ctx, sess.ID, prompt)
137	if err != nil {
138		return fmt.Errorf("failed to start agent processing stream: %w", err)
139	}
140
141	messageEvents := a.Messages.Subscribe(ctx)
142	readBts := 0
143
144	for {
145		select {
146		case result := <-done:
147			stopSpinner()
148
149			if result.Error != nil {
150				if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) {
151					slog.Info("Agent processing cancelled", "session_id", sess.ID)
152					return nil
153				}
154				return fmt.Errorf("agent processing failed: %w", result.Error)
155			}
156
157			part := result.Message.Content().String()[readBts:]
158			fmt.Println(part)
159
160			slog.Info("Non-interactive run completed", "session_id", sess.ID)
161			return nil
162
163		case event := <-messageEvents:
164			msg := event.Payload
165			if msg.SessionID == sess.ID && msg.Role == message.Assistant && len(msg.Parts) > 0 {
166				stopSpinner()
167				part := msg.Content().String()[readBts:]
168				fmt.Print(part)
169				readBts += len(part)
170			}
171
172		case <-ctx.Done():
173			stopSpinner()
174			return ctx.Err()
175		}
176	}
177}
178
179func (app *App) UpdateAgentModel() error {
180	return app.CoderAgent.UpdateModel()
181}
182
183func (app *App) setupEvents() {
184	ctx, cancel := context.WithCancel(app.globalCtx)
185	app.eventsCtx = ctx
186	setupSubscriber(ctx, app.serviceEventsWG, "sessions", app.Sessions.Subscribe, app.events)
187	setupSubscriber(ctx, app.serviceEventsWG, "messages", app.Messages.Subscribe, app.events)
188	setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events)
189	setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events)
190	cleanupFunc := func() {
191		cancel()
192		app.serviceEventsWG.Wait()
193	}
194	app.cleanupFuncs = append(app.cleanupFuncs, cleanupFunc)
195}
196
197func setupSubscriber[T any](
198	ctx context.Context,
199	wg *sync.WaitGroup,
200	name string,
201	subscriber func(context.Context) <-chan pubsub.Event[T],
202	outputCh chan<- tea.Msg,
203) {
204	wg.Add(1)
205	go func() {
206		defer wg.Done()
207		subCh := subscriber(ctx)
208		for {
209			select {
210			case event, ok := <-subCh:
211				if !ok {
212					slog.Debug("subscription channel closed", "name", name)
213					return
214				}
215				var msg tea.Msg = event
216				select {
217				case outputCh <- msg:
218				case <-time.After(2 * time.Second):
219					slog.Warn("message dropped due to slow consumer", "name", name)
220				case <-ctx.Done():
221					slog.Debug("subscription cancelled", "name", name)
222					return
223				}
224			case <-ctx.Done():
225				slog.Debug("subscription cancelled", "name", name)
226				return
227			}
228		}
229	}()
230}
231
232func (app *App) InitCoderAgent() error {
233	coderAgentCfg := app.config.Agents["coder"]
234	if coderAgentCfg.ID == "" {
235		return fmt.Errorf("coder agent configuration is missing")
236	}
237	var err error
238	app.CoderAgent, err = agent.NewAgent(
239		coderAgentCfg,
240		app.Permissions,
241		app.Sessions,
242		app.Messages,
243		app.History,
244		app.LSPClients,
245	)
246	if err != nil {
247		slog.Error("Failed to create coder agent", "err", err)
248		return err
249	}
250	setupSubscriber(app.eventsCtx, app.serviceEventsWG, "coderAgent", app.CoderAgent.Subscribe, app.events)
251	return nil
252}
253
254func (app *App) Subscribe(program *tea.Program) {
255	defer log.RecoverPanic("app.Subscribe", func() {
256		slog.Info("TUI subscription panic - attempting graceful shutdown")
257		program.Quit()
258	})
259
260	app.tuiWG.Add(1)
261	tuiCtx, tuiCancel := context.WithCancel(app.globalCtx)
262	app.cleanupFuncs = append(app.cleanupFuncs, func() {
263		slog.Debug("Cancelling TUI message handler")
264		tuiCancel()
265		app.tuiWG.Wait()
266	})
267	defer app.tuiWG.Done()
268	for {
269		select {
270		case <-tuiCtx.Done():
271			slog.Debug("TUI message handler shutting down")
272			return
273		case msg, ok := <-app.events:
274			if !ok {
275				slog.Debug("TUI message channel closed")
276				return
277			}
278			program.Send(msg)
279		}
280	}
281}
282
283// Shutdown performs a clean shutdown of the application
284func (app *App) Shutdown() {
285	if app.CoderAgent != nil {
286		app.CoderAgent.CancelAll()
287	}
288	app.cancelFuncsMutex.Lock()
289	for _, cancel := range app.watcherCancelFuncs {
290		cancel()
291	}
292	app.cancelFuncsMutex.Unlock()
293	app.lspWatcherWG.Wait()
294
295	app.clientsMutex.RLock()
296	clients := make(map[string]*lsp.Client, len(app.LSPClients))
297	maps.Copy(clients, app.LSPClients)
298	app.clientsMutex.RUnlock()
299
300	for name, client := range clients {
301		shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second)
302		if err := client.Shutdown(shutdownCtx); err != nil {
303			slog.Error("Failed to shutdown LSP client", "name", name, "error", err)
304		}
305		cancel()
306	}
307
308	for _, cleanup := range app.cleanupFuncs {
309		if cleanup != nil {
310			cleanup()
311		}
312	}
313}