.gitignore 🔗
@@ -41,6 +41,6 @@ debug.log
.env
.env.local
-.termai
+.opencode
internal/assets/diff/index.mjs
Kujtim Hoxha created
.gitignore | 2
.opencode.json | 0
cmd/git/main.go | 4
cmd/root.go | 253 +++++++++++++++++++------
internal/app/app.go | 76 +++++++
internal/app/lsp.go | 108 +++++++++++
internal/app/services.go | 64 ------
internal/config/config.go | 20 -
internal/history/file.go | 73 +++---
internal/llm/agent/agent-tool.go | 10
internal/llm/agent/agent.go | 53 ++--
internal/llm/agent/coder.go | 5
internal/llm/agent/task.go | 3
internal/message/message.go | 46 ++--
internal/session/session.go | 44 ++--
internal/tui/components/chat/messages.go | 5
internal/tui/components/repl/editor.go | 4
internal/tui/components/repl/messages.go | 7
internal/tui/components/repl/sessions.go | 4
internal/tui/page/chat.go | 8
internal/tui/tui.go | 6
21 files changed, 514 insertions(+), 281 deletions(-)
@@ -41,6 +41,6 @@ debug.log
.env
.env.local
-.termai
+.opencode
internal/assets/diff/index.mjs
@@ -1,4 +0,0 @@
-package main
-
-func main() {
-}
@@ -2,9 +2,10 @@ package cmd
import (
"context"
- "log/slog"
+ "fmt"
"os"
"sync"
+ "time"
tea "github.com/charmbracelet/bubbletea"
"github.com/kujtimiihoxha/termai/internal/app"
@@ -13,6 +14,7 @@ import (
"github.com/kujtimiihoxha/termai/internal/db"
"github.com/kujtimiihoxha/termai/internal/llm/agent"
"github.com/kujtimiihoxha/termai/internal/logging"
+ "github.com/kujtimiihoxha/termai/internal/pubsub"
"github.com/kujtimiihoxha/termai/internal/tui"
zone "github.com/lrstanley/bubblezone"
"github.com/spf13/cobra"
@@ -23,111 +25,229 @@ var rootCmd = &cobra.Command{
Short: "A terminal ai assistant",
Long: `A terminal ai assistant`,
RunE: func(cmd *cobra.Command, args []string) error {
+ // If the help flag is set, show the help message
if cmd.Flag("help").Changed {
cmd.Help()
return nil
}
+
+ // Load the config
debug, _ := cmd.Flags().GetBool("debug")
- err := config.Load(debug)
+ cwd, _ := cmd.Flags().GetString("cwd")
+ if cwd != "" {
+ err := os.Chdir(cwd)
+ if err != nil {
+ return fmt.Errorf("failed to change directory: %v", err)
+ }
+ }
+ if cwd == "" {
+ c, err := os.Getwd()
+ if err != nil {
+ return fmt.Errorf("failed to get current working directory: %v", err)
+ }
+ cwd = c
+ }
+ _, err := config.Load(cwd, debug)
if err != nil {
return err
}
- cfg := config.Get()
- defaultLevel := slog.LevelInfo
- if cfg.Debug {
- defaultLevel = slog.LevelDebug
- }
- logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
- Level: defaultLevel,
- }))
- slog.SetDefault(logger)
err = assets.WriteAssets()
if err != nil {
- return err
+ logging.Error("Error writing assets: %v", err)
}
+ // Connect DB, this will also run migrations
conn, err := db.Connect()
if err != nil {
return err
}
- ctx := context.Background()
+
+ // Create main context for the application
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
app := app.New(ctx, conn)
- logging.Info("Starting termai...")
+
+ // Set up the TUI
zone.NewGlobal()
- tui := tea.NewProgram(
+ program := tea.NewProgram(
tui.New(app),
tea.WithAltScreen(),
tea.WithMouseCellMotion(),
)
- logging.Info("Setting up subscriptions...")
- ch, unsub := setupSubscriptions(app)
- defer unsub()
+ // Initialize MCP tools in the background
+ initMCPTools(ctx, app)
+
+ // Setup the subscriptions, this will send services events to the TUI
+ ch, cancelSubs := setupSubscriptions(app)
+
+ // Create a context for the TUI message handler
+ tuiCtx, tuiCancel := context.WithCancel(ctx)
+ var tuiWg sync.WaitGroup
+ tuiWg.Add(1)
+
+ // Set up message handling for the TUI
go func() {
- // Set this up once
- agent.GetMcpTools(ctx, app.Permissions)
- for msg := range ch {
- tui.Send(msg)
+ defer tuiWg.Done()
+ defer func() {
+ if r := recover(); r != nil {
+ logging.Error("Panic in TUI message handling: %v", r)
+ attemptTUIRecovery(program)
+ }
+ }()
+
+ for {
+ select {
+ case <-tuiCtx.Done():
+ logging.Info("TUI message handler shutting down")
+ return
+ case msg, ok := <-ch:
+ if !ok {
+ logging.Info("TUI message channel closed")
+ return
+ }
+ program.Send(msg)
+ }
}
}()
- if _, err := tui.Run(); err != nil {
- return err
+
+ // Cleanup function for when the program exits
+ cleanup := func() {
+ // Shutdown the app
+ app.Shutdown()
+
+ // Cancel subscriptions first
+ cancelSubs()
+
+ // Then cancel TUI message handler
+ tuiCancel()
+
+ // Wait for TUI message handler to finish
+ tuiWg.Wait()
+
+ logging.Info("All goroutines cleaned up")
+ }
+
+ // Run the TUI
+ result, err := program.Run()
+ cleanup()
+
+ if err != nil {
+ logging.Error("TUI error: %v", err)
+ return fmt.Errorf("TUI error: %v", err)
}
+
+ logging.Info("TUI exited with result: %v", result)
return nil
},
}
-func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
- ch := make(chan tea.Msg)
- wg := sync.WaitGroup{}
- ctx, cancel := context.WithCancel(app.Context)
- {
- sub := logging.Subscribe(ctx)
- wg.Add(1)
- go func() {
- for ev := range sub {
- ch <- ev
+// attemptTUIRecovery tries to recover the TUI after a panic
+func attemptTUIRecovery(program *tea.Program) {
+ logging.Info("Attempting to recover TUI after panic")
+
+ // We could try to restart the TUI or gracefully exit
+ // For now, we'll just quit the program to avoid further issues
+ program.Quit()
+}
+
+func initMCPTools(ctx context.Context, app *app.App) {
+ go func() {
+ defer func() {
+ if r := recover(); r != nil {
+ logging.Error("Panic in MCP goroutine: %v", r)
}
- wg.Done()
}()
- }
- {
- sub := app.Sessions.Subscribe(ctx)
- wg.Add(1)
- go func() {
- for ev := range sub {
- ch <- ev
+
+ // Create a context with timeout for the initial MCP tools fetch
+ ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+
+ // Set this up once with proper error handling
+ agent.GetMcpTools(ctxWithTimeout, app.Permissions)
+ logging.Info("MCP message handling goroutine exiting")
+ }()
+}
+
+func setupSubscriber[T any](
+ ctx context.Context,
+ wg *sync.WaitGroup,
+ name string,
+ subscriber func(context.Context) <-chan pubsub.Event[T],
+ outputCh chan<- tea.Msg,
+) {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ defer func() {
+ if r := recover(); r != nil {
+ logging.Error("Panic in %s subscription goroutine: %v", name, r)
}
- wg.Done()
}()
- }
- {
- sub := app.Messages.Subscribe(ctx)
- wg.Add(1)
- go func() {
- for ev := range sub {
- ch <- ev
+
+ for {
+ select {
+ case event, ok := <-subscriber(ctx):
+ if !ok {
+ logging.Info("%s subscription channel closed", name)
+ return
+ }
+
+ // Convert generic event to tea.Msg if needed
+ var msg tea.Msg = event
+
+ // Non-blocking send with timeout to prevent deadlocks
+ select {
+ case outputCh <- msg:
+ case <-time.After(500 * time.Millisecond):
+ logging.Warn("%s message dropped due to slow consumer", name)
+ case <-ctx.Done():
+ logging.Info("%s subscription cancelled", name)
+ return
+ }
+ case <-ctx.Done():
+ logging.Info("%s subscription cancelled", name)
+ return
}
- wg.Done()
- }()
- }
- {
- sub := app.Permissions.Subscribe(ctx)
- wg.Add(1)
+ }
+ }()
+}
+
+func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
+ ch := make(chan tea.Msg, 100)
+ // Add a buffer to prevent blocking
+ wg := sync.WaitGroup{}
+ ctx, cancel := context.WithCancel(context.Background())
+ // Setup each subscription using the helper
+ setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
+ setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
+ setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
+ setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
+
+ // Return channel and a cleanup function
+ cleanupFunc := func() {
+ logging.Info("Cancelling all subscriptions")
+ cancel() // Signal all goroutines to stop
+
+ // Wait with a timeout for all goroutines to complete
+ waitCh := make(chan struct{})
go func() {
- for ev := range sub {
- ch <- ev
- }
- wg.Done()
+ wg.Wait()
+ close(waitCh)
}()
+
+ select {
+ case <-waitCh:
+ logging.Info("All subscription goroutines completed successfully")
+ case <-time.After(5 * time.Second):
+ logging.Warn("Timed out waiting for some subscription goroutines to complete")
+ }
+
+ close(ch) // Safe to close after all writers are done or timed out
}
- return ch, func() {
- cancel()
- wg.Wait()
- close(ch)
- }
+ return ch, cleanupFunc
}
func Execute() {
@@ -139,5 +259,6 @@ func Execute() {
func init() {
rootCmd.Flags().BoolP("help", "h", false, "Help")
- rootCmd.Flags().BoolP("debug", "d", false, "Help")
+ rootCmd.Flags().BoolP("debug", "d", false, "Debug")
+ rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
}
@@ -0,0 +1,76 @@
+package app
+
+import (
+ "context"
+ "database/sql"
+ "maps"
+ "sync"
+ "time"
+
+ "github.com/kujtimiihoxha/termai/internal/db"
+ "github.com/kujtimiihoxha/termai/internal/history"
+ "github.com/kujtimiihoxha/termai/internal/logging"
+ "github.com/kujtimiihoxha/termai/internal/lsp"
+ "github.com/kujtimiihoxha/termai/internal/message"
+ "github.com/kujtimiihoxha/termai/internal/permission"
+ "github.com/kujtimiihoxha/termai/internal/session"
+)
+
+type App struct {
+ Sessions session.Service
+ Messages message.Service
+ Files history.Service
+ Permissions permission.Service
+
+ LSPClients map[string]*lsp.Client
+
+ clientsMutex sync.RWMutex
+
+ watcherCancelFuncs []context.CancelFunc
+ cancelFuncsMutex sync.Mutex
+ watcherWG sync.WaitGroup
+}
+
+func New(ctx context.Context, conn *sql.DB) *App {
+ q := db.New(conn)
+ sessions := session.NewService(q)
+ messages := message.NewService(q)
+ files := history.NewService(q)
+
+ app := &App{
+ Sessions: sessions,
+ Messages: messages,
+ Files: files,
+ Permissions: permission.NewPermissionService(),
+ LSPClients: make(map[string]*lsp.Client),
+ }
+
+ app.initLSPClients(ctx)
+
+ return app
+}
+
+// Shutdown performs a clean shutdown of the application
+func (app *App) Shutdown() {
+ // Cancel all watcher goroutines
+ app.cancelFuncsMutex.Lock()
+ for _, cancel := range app.watcherCancelFuncs {
+ cancel()
+ }
+ app.cancelFuncsMutex.Unlock()
+ app.watcherWG.Wait()
+
+ // Perform additional cleanup for LSP clients
+ app.clientsMutex.RLock()
+ clients := make(map[string]*lsp.Client, len(app.LSPClients))
+ maps.Copy(clients, app.LSPClients)
+ app.clientsMutex.RUnlock()
+
+ for name, client := range clients {
+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ if err := client.Shutdown(shutdownCtx); err != nil {
+ logging.Error("Failed to shutdown LSP client", "name", name, "error", err)
+ }
+ cancel()
+ }
+}
@@ -0,0 +1,108 @@
+package app
+
+import (
+ "context"
+ "time"
+
+ "github.com/kujtimiihoxha/termai/internal/config"
+ "github.com/kujtimiihoxha/termai/internal/logging"
+ "github.com/kujtimiihoxha/termai/internal/lsp"
+ "github.com/kujtimiihoxha/termai/internal/lsp/watcher"
+)
+
+func (app *App) initLSPClients(ctx context.Context) {
+ cfg := config.Get()
+
+ // Initialize LSP clients
+ for name, clientConfig := range cfg.LSP {
+ app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...)
+ }
+}
+
+// createAndStartLSPClient creates a new LSP client, initializes it, and starts its workspace watcher
+func (app *App) createAndStartLSPClient(ctx context.Context, name string, command string, args ...string) {
+ // Create a specific context for initialization with a timeout
+ initCtx, initCancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer initCancel()
+
+ // Create the LSP client
+ lspClient, err := lsp.NewClient(initCtx, command, args...)
+ if err != nil {
+ logging.Error("Failed to create LSP client for", name, err)
+ return
+ }
+
+ // Initialize with the initialization context
+ _, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory())
+ if err != nil {
+ logging.Error("Initialize failed", "name", name, "error", err)
+ // Clean up the client to prevent resource leaks
+ lspClient.Close()
+ return
+ }
+
+ // Create a child context that can be canceled when the app is shutting down
+ watchCtx, cancelFunc := context.WithCancel(ctx)
+ workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient)
+
+ // Store the cancel function to be called during cleanup
+ app.cancelFuncsMutex.Lock()
+ app.watcherCancelFuncs = append(app.watcherCancelFuncs, cancelFunc)
+ app.cancelFuncsMutex.Unlock()
+
+ // Add the watcher to a WaitGroup to track active goroutines
+ app.watcherWG.Add(1)
+
+ // Add to map with mutex protection before starting goroutine
+ app.clientsMutex.Lock()
+ app.LSPClients[name] = lspClient
+ app.clientsMutex.Unlock()
+
+ go app.runWorkspaceWatcher(watchCtx, name, workspaceWatcher)
+}
+
+// runWorkspaceWatcher executes the workspace watcher for an LSP client
+func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceWatcher *watcher.WorkspaceWatcher) {
+ defer app.watcherWG.Done()
+ defer func() {
+ if r := recover(); r != nil {
+ logging.Error("LSP client crashed", "client", name, "panic", r)
+
+ // Try to restart the client
+ app.restartLSPClient(ctx, name)
+ }
+ }()
+
+ workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory())
+ logging.Info("Workspace watcher stopped", "client", name)
+}
+
+// restartLSPClient attempts to restart a crashed or failed LSP client
+func (app *App) restartLSPClient(ctx context.Context, name string) {
+ // Get the original configuration
+ cfg := config.Get()
+ clientConfig, exists := cfg.LSP[name]
+ if !exists {
+ logging.Error("Cannot restart client, configuration not found", "client", name)
+ return
+ }
+
+ // Clean up the old client if it exists
+ app.clientsMutex.Lock()
+ oldClient, exists := app.LSPClients[name]
+ if exists {
+ delete(app.LSPClients, name) // Remove from map before potentially slow shutdown
+ }
+ app.clientsMutex.Unlock()
+
+ if exists && oldClient != nil {
+ // Try to shut it down gracefully, but don't block on errors
+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ _ = oldClient.Shutdown(shutdownCtx)
+ cancel()
+ }
+
+ // Create a new client using the shared function
+ app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...)
+ logging.Info("Successfully restarted LSP client", "client", name)
+}
@@ -1,64 +0,0 @@
-package app
-
-import (
- "context"
- "database/sql"
-
- "github.com/kujtimiihoxha/termai/internal/config"
- "github.com/kujtimiihoxha/termai/internal/db"
- "github.com/kujtimiihoxha/termai/internal/history"
- "github.com/kujtimiihoxha/termai/internal/logging"
- "github.com/kujtimiihoxha/termai/internal/lsp"
- "github.com/kujtimiihoxha/termai/internal/lsp/watcher"
- "github.com/kujtimiihoxha/termai/internal/message"
- "github.com/kujtimiihoxha/termai/internal/permission"
- "github.com/kujtimiihoxha/termai/internal/session"
-)
-
-type App struct {
- Context context.Context
-
- Sessions session.Service
- Messages message.Service
- Files history.Service
- Permissions permission.Service
-
- LSPClients map[string]*lsp.Client
-}
-
-func New(ctx context.Context, conn *sql.DB) *App {
- cfg := config.Get()
- logging.Info("Debug mode enabled")
-
- q := db.New(conn)
- sessions := session.NewService(ctx, q)
- messages := message.NewService(ctx, q)
- files := history.NewService(ctx, q)
-
- app := &App{
- Context: ctx,
- Sessions: sessions,
- Messages: messages,
- Files: files,
- Permissions: permission.NewPermissionService(),
- LSPClients: make(map[string]*lsp.Client),
- }
-
- for name, client := range cfg.LSP {
- lspClient, err := lsp.NewClient(ctx, client.Command, client.Args...)
- workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient)
- if err != nil {
- logging.Error("Failed to create LSP client for", name, err)
- continue
- }
-
- _, err = lspClient.InitializeLSPClient(ctx, config.WorkingDirectory())
- if err != nil {
- logging.Error("Initialize failed", "error", err)
- continue
- }
- go workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory())
- app.LSPClients[name] = lspClient
- }
- return app
-}
@@ -83,9 +83,9 @@ var cfg *Config
// Load initializes the configuration from environment variables and config files.
// If debug is true, debug mode is enabled and log level is set to debug.
// It returns an error if configuration loading fails.
-func Load(workingDir string, debug bool) error {
+func Load(workingDir string, debug bool) (*Config, error) {
if cfg != nil {
- return nil
+ return cfg, nil
}
cfg = &Config{
@@ -101,7 +101,7 @@ func Load(workingDir string, debug bool) error {
// Read global config
if err := readConfig(viper.ReadInConfig()); err != nil {
- return err
+ return cfg, err
}
// Load and merge local config
@@ -109,7 +109,7 @@ func Load(workingDir string, debug bool) error {
// Apply configuration to the struct
if err := viper.Unmarshal(cfg); err != nil {
- return err
+ return cfg, fmt.Errorf("failed to unmarshal config: %w", err)
}
applyDefaultValues()
@@ -123,7 +123,7 @@ func Load(workingDir string, debug bool) error {
Level: defaultLevel,
}))
slog.SetDefault(logger)
- return nil
+ return cfg, nil
}
// configureViper sets up viper's configuration paths and environment variables.
@@ -237,7 +237,7 @@ func readConfig(err error) error {
return nil
}
- return err
+ return fmt.Errorf("failed to read config: %w", err)
}
// mergeLocalConfig loads and merges configuration from the local directory.
@@ -264,14 +264,6 @@ func applyDefaultValues() {
}
}
-// setWorkingDirectory stores the current working directory in the configuration.
-func setWorkingDirectory() {
- workdir, err := os.Getwd()
- if err == nil {
- viper.Set("wd", workdir)
- }
-}
-
// Get returns the current configuration.
// It's safe to call this function multiple times.
func Get() *Config {
@@ -27,45 +27,43 @@ type File struct {
type Service interface {
pubsub.Suscriber[File]
- Create(sessionID, path, content string) (File, error)
- CreateVersion(sessionID, path, content string) (File, error)
- Get(id string) (File, error)
- GetByPathAndSession(path, sessionID string) (File, error)
- ListBySession(sessionID string) ([]File, error)
- ListLatestSessionFiles(sessionID string) ([]File, error)
- Update(file File) (File, error)
- Delete(id string) error
- DeleteSessionFiles(sessionID string) error
+ Create(ctx context.Context, sessionID, path, content string) (File, error)
+ CreateVersion(ctx context.Context, sessionID, path, content string) (File, error)
+ Get(ctx context.Context, id string) (File, error)
+ GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error)
+ ListBySession(ctx context.Context, sessionID string) ([]File, error)
+ ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
+ Update(ctx context.Context, file File) (File, error)
+ Delete(ctx context.Context, id string) error
+ DeleteSessionFiles(ctx context.Context, sessionID string) error
}
type service struct {
*pubsub.Broker[File]
- q db.Querier
- ctx context.Context
+ q db.Querier
}
-func NewService(ctx context.Context, q db.Querier) Service {
+func NewService(q db.Querier) Service {
return &service{
Broker: pubsub.NewBroker[File](),
q: q,
- ctx: ctx,
}
}
-func (s *service) Create(sessionID, path, content string) (File, error) {
- return s.createWithVersion(sessionID, path, content, InitialVersion)
+func (s *service) Create(ctx context.Context, sessionID, path, content string) (File, error) {
+ return s.createWithVersion(ctx, sessionID, path, content, InitialVersion)
}
-func (s *service) CreateVersion(sessionID, path, content string) (File, error) {
+func (s *service) CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
// Get the latest version for this path
- files, err := s.q.ListFilesByPath(s.ctx, path)
+ files, err := s.q.ListFilesByPath(ctx, path)
if err != nil {
return File{}, err
}
if len(files) == 0 {
// No previous versions, create initial
- return s.Create(sessionID, path, content)
+ return s.Create(ctx, sessionID, path, content)
}
// Get the latest version
@@ -89,11 +87,11 @@ func (s *service) CreateVersion(sessionID, path, content string) (File, error) {
nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt)
}
- return s.createWithVersion(sessionID, path, content, nextVersion)
+ return s.createWithVersion(ctx, sessionID, path, content, nextVersion)
}
-func (s *service) createWithVersion(sessionID, path, content, version string) (File, error) {
- dbFile, err := s.q.CreateFile(s.ctx, db.CreateFileParams{
+func (s *service) createWithVersion(ctx context.Context, sessionID, path, content, version string) (File, error) {
+ dbFile, err := s.q.CreateFile(ctx, db.CreateFileParams{
ID: uuid.New().String(),
SessionID: sessionID,
Path: path,
@@ -108,16 +106,16 @@ func (s *service) createWithVersion(sessionID, path, content, version string) (F
return file, nil
}
-func (s *service) Get(id string) (File, error) {
- dbFile, err := s.q.GetFile(s.ctx, id)
+func (s *service) Get(ctx context.Context, id string) (File, error) {
+ dbFile, err := s.q.GetFile(ctx, id)
if err != nil {
return File{}, err
}
return s.fromDBItem(dbFile), nil
}
-func (s *service) GetByPathAndSession(path, sessionID string) (File, error) {
- dbFile, err := s.q.GetFileByPathAndSession(s.ctx, db.GetFileByPathAndSessionParams{
+func (s *service) GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
+ dbFile, err := s.q.GetFileByPathAndSession(ctx, db.GetFileByPathAndSessionParams{
Path: path,
SessionID: sessionID,
})
@@ -127,8 +125,8 @@ func (s *service) GetByPathAndSession(path, sessionID string) (File, error) {
return s.fromDBItem(dbFile), nil
}
-func (s *service) ListBySession(sessionID string) ([]File, error) {
- dbFiles, err := s.q.ListFilesBySession(s.ctx, sessionID)
+func (s *service) ListBySession(ctx context.Context, sessionID string) ([]File, error) {
+ dbFiles, err := s.q.ListFilesBySession(ctx, sessionID)
if err != nil {
return nil, err
}
@@ -139,8 +137,8 @@ func (s *service) ListBySession(sessionID string) ([]File, error) {
return files, nil
}
-func (s *service) ListLatestSessionFiles(sessionID string) ([]File, error) {
- dbFiles, err := s.q.ListLatestSessionFiles(s.ctx, sessionID)
+func (s *service) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
+ dbFiles, err := s.q.ListLatestSessionFiles(ctx, sessionID)
if err != nil {
return nil, err
}
@@ -151,8 +149,8 @@ func (s *service) ListLatestSessionFiles(sessionID string) ([]File, error) {
return files, nil
}
-func (s *service) Update(file File) (File, error) {
- dbFile, err := s.q.UpdateFile(s.ctx, db.UpdateFileParams{
+func (s *service) Update(ctx context.Context, file File) (File, error) {
+ dbFile, err := s.q.UpdateFile(ctx, db.UpdateFileParams{
ID: file.ID,
Content: file.Content,
Version: file.Version,
@@ -165,12 +163,12 @@ func (s *service) Update(file File) (File, error) {
return updatedFile, nil
}
-func (s *service) Delete(id string) error {
- file, err := s.Get(id)
+func (s *service) Delete(ctx context.Context, id string) error {
+ file, err := s.Get(ctx, id)
if err != nil {
return err
}
- err = s.q.DeleteFile(s.ctx, id)
+ err = s.q.DeleteFile(ctx, id)
if err != nil {
return err
}
@@ -178,13 +176,13 @@ func (s *service) Delete(id string) error {
return nil
}
-func (s *service) DeleteSessionFiles(sessionID string) error {
- files, err := s.ListBySession(sessionID)
+func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error {
+ files, err := s.ListBySession(ctx, sessionID)
if err != nil {
return err
}
for _, file := range files {
- err = s.Delete(file.ID)
+ err = s.Delete(ctx, file.ID)
if err != nil {
return err
}
@@ -203,4 +201,3 @@ func (s *service) fromDBItem(item db.File) File {
UpdatedAt: item.UpdatedAt,
}
}
-
@@ -51,7 +51,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse(fmt.Sprintf("error creating agent: %s", err)), nil
}
- session, err := b.app.Sessions.CreateTaskSession(call.ID, b.parentSessionID, "New Agent Session")
+ session, err := b.app.Sessions.CreateTaskSession(ctx, call.ID, b.parentSessionID, "New Agent Session")
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil
}
@@ -61,7 +61,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil
}
- messages, err := b.app.Messages.List(session.ID)
+ messages, err := b.app.Messages.List(ctx, session.ID)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error listing messages: %s", err)), nil
}
@@ -74,11 +74,11 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse("no assistant message found"), nil
}
- updatedSession, err := b.app.Sessions.Get(session.ID)
+ updatedSession, err := b.app.Sessions.Get(ctx, session.ID)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
}
- parentSession, err := b.app.Sessions.Get(b.parentSessionID)
+ parentSession, err := b.app.Sessions.Get(ctx, b.parentSessionID)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
}
@@ -87,7 +87,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
parentSession.PromptTokens += updatedSession.PromptTokens
parentSession.CompletionTokens += updatedSession.CompletionTokens
- _, err = b.app.Sessions.Save(parentSession)
+ _, err = b.app.Sessions.Save(ctx, parentSession)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
}
@@ -48,7 +48,7 @@ func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content st
return
}
- session, err := c.Sessions.Get(sessionID)
+ session, err := c.Sessions.Get(ctx, sessionID)
if err != nil {
return
}
@@ -56,12 +56,12 @@ func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content st
session.Title = response.Content
session.Title = strings.TrimSpace(session.Title)
session.Title = strings.ReplaceAll(session.Title, "\n", " ")
- c.Sessions.Save(session)
+ c.Sessions.Save(ctx, session)
}
}
-func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
- session, err := c.Sessions.Get(sessionID)
+func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
+ session, err := c.Sessions.Get(ctx, sessionID)
if err != nil {
return err
}
@@ -75,11 +75,12 @@ func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.
session.CompletionTokens += usage.OutputTokens
session.PromptTokens += usage.InputTokens
- _, err = c.Sessions.Save(session)
+ _, err = c.Sessions.Save(ctx, session)
return err
}
func (c *agent) processEvent(
+ ctx context.Context,
sessionID string,
assistantMsg *message.Message,
event provider.ProviderEvent,
@@ -87,10 +88,10 @@ func (c *agent) processEvent(
switch event.Type {
case provider.EventThinkingDelta:
assistantMsg.AppendReasoningContent(event.Content)
- return c.Messages.Update(*assistantMsg)
+ return c.Messages.Update(ctx, *assistantMsg)
case provider.EventContentDelta:
assistantMsg.AppendContent(event.Content)
- return c.Messages.Update(*assistantMsg)
+ return c.Messages.Update(ctx, *assistantMsg)
case provider.EventError:
if errors.Is(event.Error, context.Canceled) {
return nil
@@ -105,11 +106,11 @@ func (c *agent) processEvent(
case provider.EventComplete:
assistantMsg.SetToolCalls(event.Response.ToolCalls)
assistantMsg.AddFinish(event.Response.FinishReason)
- err := c.Messages.Update(*assistantMsg)
+ err := c.Messages.Update(ctx, *assistantMsg)
if err != nil {
return err
}
- return c.TrackUsage(sessionID, c.model, event.Response.Usage)
+ return c.TrackUsage(ctx, sessionID, c.model, event.Response.Usage)
}
return nil
@@ -237,7 +238,7 @@ func (c *agent) handleToolExecution(
for _, toolResult := range toolResults {
parts = append(parts, toolResult)
}
- msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
+ msg, err := c.Messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
Role: message.Tool,
Parts: parts,
})
@@ -247,7 +248,7 @@ func (c *agent) handleToolExecution(
func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
- messages, err := c.Messages.List(sessionID)
+ messages, err := c.Messages.List(ctx, sessionID)
if err != nil {
return err
}
@@ -256,7 +257,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string)
go c.handleTitleGeneration(ctx, sessionID, content)
}
- userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
+ userMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.User,
Parts: []message.ContentPart{
message.TextContent{
@@ -272,7 +273,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string)
for {
select {
case <-ctx.Done():
- assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
+ assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{},
})
@@ -280,7 +281,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string)
return err
}
assistantMsg.AddFinish("canceled")
- c.Messages.Update(assistantMsg)
+ c.Messages.Update(ctx, assistantMsg)
return context.Canceled
default:
// Continue processing
@@ -289,7 +290,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string)
eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
if err != nil {
if errors.Is(err, context.Canceled) {
- assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
+ assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{},
})
@@ -297,13 +298,13 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string)
return err
}
assistantMsg.AddFinish("canceled")
- c.Messages.Update(assistantMsg)
+ c.Messages.Update(ctx, assistantMsg)
return context.Canceled
}
return err
}
- assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
+ assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{},
Model: c.model.ID,
@@ -314,22 +315,22 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string)
ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
for event := range eventChan {
- err = c.processEvent(sessionID, &assistantMsg, event)
+ err = c.processEvent(ctx, sessionID, &assistantMsg, event)
if err != nil {
if errors.Is(err, context.Canceled) {
assistantMsg.AddFinish("canceled")
- c.Messages.Update(assistantMsg)
+ c.Messages.Update(ctx, assistantMsg)
return context.Canceled
}
assistantMsg.AddFinish("error:" + err.Error())
- c.Messages.Update(assistantMsg)
+ c.Messages.Update(ctx, assistantMsg)
return err
}
select {
case <-ctx.Done():
assistantMsg.AddFinish("canceled")
- c.Messages.Update(assistantMsg)
+ c.Messages.Update(ctx, assistantMsg)
return context.Canceled
default:
}
@@ -339,7 +340,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string)
select {
case <-ctx.Done():
assistantMsg.AddFinish("canceled")
- c.Messages.Update(assistantMsg)
+ c.Messages.Update(ctx, assistantMsg)
return context.Canceled
default:
// Continue processing
@@ -349,13 +350,13 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string)
if err != nil {
if errors.Is(err, context.Canceled) {
assistantMsg.AddFinish("canceled")
- c.Messages.Update(assistantMsg)
+ c.Messages.Update(ctx, assistantMsg)
return context.Canceled
}
return err
}
- c.Messages.Update(assistantMsg)
+ c.Messages.Update(ctx, assistantMsg)
if len(assistantMsg.ToolCalls()) == 0 {
break
@@ -370,7 +371,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string)
select {
case <-ctx.Done():
assistantMsg.AddFinish("canceled")
- c.Messages.Update(assistantMsg)
+ c.Messages.Update(ctx, assistantMsg)
return context.Canceled
default:
// Continue processing
@@ -383,7 +384,7 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
maxTokens := config.Get().Model.CoderMaxTokens
providerConfig, ok := config.Get().Providers[model.Provider]
- if !ok || !providerConfig.Enabled {
+ if !ok || providerConfig.Disabled {
return nil, nil, errors.New("provider is not enabled")
}
var agentProvider provider.Provider
@@ -40,12 +40,13 @@ func NewCoderAgent(app *app.App) (Agent, error) {
return nil, errors.New("model not supported")
}
- agentProvider, titleGenerator, err := getAgentProviders(app.Context, model)
+ ctx := context.Background()
+ agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
if err != nil {
return nil, err
}
- otherTools := GetMcpTools(app.Context, app.Permissions)
+ otherTools := GetMcpTools(ctx, app.Permissions)
if len(app.LSPClients) > 0 {
otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients))
}
@@ -24,7 +24,8 @@ func NewTaskAgent(app *app.App) (Agent, error) {
return nil, errors.New("model not supported")
}
- agentProvider, titleGenerator, err := getAgentProviders(app.Context, model)
+ ctx := context.Background()
+ agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
if err != nil {
return nil, err
}
@@ -20,34 +20,32 @@ type CreateMessageParams struct {
type Service interface {
pubsub.Suscriber[Message]
- Create(sessionID string, params CreateMessageParams) (Message, error)
- Update(message Message) error
- Get(id string) (Message, error)
- List(sessionID string) ([]Message, error)
- Delete(id string) error
- DeleteSessionMessages(sessionID string) error
+ Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
+ Update(ctx context.Context, message Message) error
+ Get(ctx context.Context, id string) (Message, error)
+ List(ctx context.Context, sessionID string) ([]Message, error)
+ Delete(ctx context.Context, id string) error
+ DeleteSessionMessages(ctx context.Context, sessionID string) error
}
type service struct {
*pubsub.Broker[Message]
- q db.Querier
- ctx context.Context
+ q db.Querier
}
-func NewService(ctx context.Context, q db.Querier) Service {
+func NewService(q db.Querier) Service {
return &service{
Broker: pubsub.NewBroker[Message](),
q: q,
- ctx: ctx,
}
}
-func (s *service) Delete(id string) error {
- message, err := s.Get(id)
+func (s *service) Delete(ctx context.Context, id string) error {
+ message, err := s.Get(ctx, id)
if err != nil {
return err
}
- err = s.q.DeleteMessage(s.ctx, message.ID)
+ err = s.q.DeleteMessage(ctx, message.ID)
if err != nil {
return err
}
@@ -55,7 +53,7 @@ func (s *service) Delete(id string) error {
return nil
}
-func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) {
+func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
if params.Role != Assistant {
params.Parts = append(params.Parts, Finish{
Reason: "stop",
@@ -66,7 +64,7 @@ func (s *service) Create(sessionID string, params CreateMessageParams) (Message,
return Message{}, err
}
- dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
+ dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
ID: uuid.New().String(),
SessionID: sessionID,
Role: string(params.Role),
@@ -84,14 +82,14 @@ func (s *service) Create(sessionID string, params CreateMessageParams) (Message,
return message, nil
}
-func (s *service) DeleteSessionMessages(sessionID string) error {
- messages, err := s.List(sessionID)
+func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
+ messages, err := s.List(ctx, sessionID)
if err != nil {
return err
}
for _, message := range messages {
if message.SessionID == sessionID {
- err = s.Delete(message.ID)
+ err = s.Delete(ctx, message.ID)
if err != nil {
return err
}
@@ -100,7 +98,7 @@ func (s *service) DeleteSessionMessages(sessionID string) error {
return nil
}
-func (s *service) Update(message Message) error {
+func (s *service) Update(ctx context.Context, message Message) error {
parts, err := marshallParts(message.Parts)
if err != nil {
return err
@@ -110,7 +108,7 @@ func (s *service) Update(message Message) error {
finishedAt.Int64 = f.Time
finishedAt.Valid = true
}
- err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{
+ err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
ID: message.ID,
Parts: string(parts),
FinishedAt: finishedAt,
@@ -122,16 +120,16 @@ func (s *service) Update(message Message) error {
return nil
}
-func (s *service) Get(id string) (Message, error) {
- dbMessage, err := s.q.GetMessage(s.ctx, id)
+func (s *service) Get(ctx context.Context, id string) (Message, error) {
+ dbMessage, err := s.q.GetMessage(ctx, id)
if err != nil {
return Message{}, err
}
return s.fromDBItem(dbMessage)
}
-func (s *service) List(sessionID string) ([]Message, error) {
- dbMessages, err := s.q.ListMessagesBySession(s.ctx, sessionID)
+func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
+ dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
if err != nil {
return nil, err
}
@@ -23,22 +23,21 @@ type Session struct {
type Service interface {
pubsub.Suscriber[Session]
- Create(title string) (Session, error)
- CreateTaskSession(toolCallID, parentSessionID, title string) (Session, error)
- Get(id string) (Session, error)
- List() ([]Session, error)
- Save(session Session) (Session, error)
- Delete(id string) error
+ Create(ctx context.Context, title string) (Session, error)
+ CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
+ Get(ctx context.Context, id string) (Session, error)
+ List(ctx context.Context) ([]Session, error)
+ Save(ctx context.Context, session Session) (Session, error)
+ Delete(ctx context.Context, id string) error
}
type service struct {
*pubsub.Broker[Session]
- q db.Querier
- ctx context.Context
+ q db.Querier
}
-func (s *service) Create(title string) (Session, error) {
- dbSession, err := s.q.CreateSession(s.ctx, db.CreateSessionParams{
+func (s *service) Create(ctx context.Context, title string) (Session, error) {
+ dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
ID: uuid.New().String(),
Title: title,
})
@@ -50,8 +49,8 @@ func (s *service) Create(title string) (Session, error) {
return session, nil
}
-func (s *service) CreateTaskSession(toolCallID, parentSessionID, title string) (Session, error) {
- dbSession, err := s.q.CreateSession(s.ctx, db.CreateSessionParams{
+func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
+ dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
ID: toolCallID,
ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
Title: title,
@@ -64,12 +63,12 @@ func (s *service) CreateTaskSession(toolCallID, parentSessionID, title string) (
return session, nil
}
-func (s *service) Delete(id string) error {
- session, err := s.Get(id)
+func (s *service) Delete(ctx context.Context, id string) error {
+ session, err := s.Get(ctx, id)
if err != nil {
return err
}
- err = s.q.DeleteSession(s.ctx, session.ID)
+ err = s.q.DeleteSession(ctx, session.ID)
if err != nil {
return err
}
@@ -77,16 +76,16 @@ func (s *service) Delete(id string) error {
return nil
}
-func (s *service) Get(id string) (Session, error) {
- dbSession, err := s.q.GetSessionByID(s.ctx, id)
+func (s *service) Get(ctx context.Context, id string) (Session, error) {
+ dbSession, err := s.q.GetSessionByID(ctx, id)
if err != nil {
return Session{}, err
}
return s.fromDBItem(dbSession), nil
}
-func (s *service) Save(session Session) (Session, error) {
- dbSession, err := s.q.UpdateSession(s.ctx, db.UpdateSessionParams{
+func (s *service) Save(ctx context.Context, session Session) (Session, error) {
+ dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
ID: session.ID,
Title: session.Title,
PromptTokens: session.PromptTokens,
@@ -101,8 +100,8 @@ func (s *service) Save(session Session) (Session, error) {
return session, nil
}
-func (s *service) List() ([]Session, error) {
- dbSessions, err := s.q.ListSessions(s.ctx)
+func (s *service) List(ctx context.Context) ([]Session, error) {
+ dbSessions, err := s.q.ListSessions(ctx)
if err != nil {
return nil, err
}
@@ -127,11 +126,10 @@ func (s service) fromDBItem(item db.Session) Session {
}
}
-func NewService(ctx context.Context, q db.Querier) Service {
+func NewService(q db.Querier) Service {
broker := pubsub.NewBroker[Session]()
return &service{
broker,
q,
- ctx,
}
}
@@ -1,6 +1,7 @@
package chat
import (
+ "context"
"encoding/json"
"fmt"
"math"
@@ -324,7 +325,7 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s
innerToolCalls := make([]string, 0)
if toolCall.Name == agent.AgentToolName {
- messages, _ := m.app.Messages.List(toolCall.ID)
+ messages, _ := m.app.Messages.List(context.Background(), toolCall.ID)
toolCalls := make([]message.ToolCall, 0)
for _, v := range messages {
toolCalls = append(toolCalls, v.ToolCalls()...)
@@ -554,7 +555,7 @@ func (m *messagesCmp) GetSize() (int, int) {
func (m *messagesCmp) SetSession(session session.Session) tea.Cmd {
m.session = session
- messages, err := m.app.Messages.List(session.ID)
+ messages, err := m.app.Messages.List(context.Background(), session.ID)
if err != nil {
return util.ReportError(err)
}
@@ -160,7 +160,7 @@ func (m *editorCmp) Send() tea.Cmd {
return util.ReportWarn("Assistant is still working on the previous message")
}
- messages, err := m.app.Messages.List(m.sessionID)
+ messages, err := m.app.Messages.List(context.Background(), m.sessionID)
if err != nil {
return util.ReportError(err)
}
@@ -177,7 +177,7 @@ func (m *editorCmp) Send() tea.Cmd {
if len(content) == 0 {
return util.ReportWarn("Message is empty")
}
- ctx, cancel := context.WithCancel(m.app.Context)
+ ctx, cancel := context.WithCancel(context.Background())
m.cancelMessage = cancel
go func() {
defer cancel()
@@ -1,6 +1,7 @@
package repl
import (
+ "context"
"encoding/json"
"fmt"
"sort"
@@ -77,8 +78,8 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.session = msg.Payload
}
case SelectedSessionMsg:
- m.session, _ = m.app.Sessions.Get(msg.SessionID)
- m.messages, _ = m.app.Messages.List(m.session.ID)
+ m.session, _ = m.app.Sessions.Get(context.Background(), msg.SessionID)
+ m.messages, _ = m.app.Messages.List(context.Background(), m.session.ID)
m.renderView()
m.viewport.GotoBottom()
}
@@ -259,7 +260,7 @@ func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.
runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
allParts = append(allParts, leftPadding.Render(runningIndicator))
- taskSessionMessages, _ := m.app.Messages.List(toolCall.ID)
+ taskSessionMessages, _ := m.app.Messages.List(context.Background(), toolCall.ID)
for _, msg := range taskSessionMessages {
if msg.Role == message.Assistant {
for _, toolCall := range msg.ToolCalls() {
@@ -1,6 +1,7 @@
package repl
import (
+ "context"
"fmt"
"strings"
@@ -57,12 +58,13 @@ var sessionKeyMapValue = sessionsKeyMap{
}
func (i *sessionsCmp) Init() tea.Cmd {
- existing, err := i.app.Sessions.List()
+ existing, err := i.app.Sessions.List(context.Background())
if err != nil {
return util.ReportError(err)
}
if len(existing) == 0 || existing[0].MessageCount > 0 {
newSession, err := i.app.Sessions.Create(
+ context.Background(),
"New Session",
)
if err != nil {
@@ -1,6 +1,8 @@
package page
import (
+ "context"
+
"github.com/charmbracelet/bubbles/key"
tea "github.com/charmbracelet/bubbletea"
"github.com/kujtimiihoxha/termai/internal/app"
@@ -36,7 +38,7 @@ func (p *chatPage) Init() tea.Cmd {
p.layout.Init(),
}
- sessions, _ := p.app.Sessions.List()
+ sessions, _ := p.app.Sessions.List(context.Background())
if len(sessions) > 0 {
p.session = sessions[0]
cmd := p.setSidebar()
@@ -92,7 +94,7 @@ func (p *chatPage) clearSidebar() {
func (p *chatPage) sendMessage(text string) tea.Cmd {
var cmds []tea.Cmd
if p.session.ID == "" {
- session, err := p.app.Sessions.Create("New Session")
+ session, err := p.app.Sessions.Create(context.Background(), "New Session")
if err != nil {
return util.ReportError(err)
}
@@ -110,7 +112,7 @@ func (p *chatPage) sendMessage(text string) tea.Cmd {
return util.ReportError(err)
}
go func() {
- a.Generate(p.app.Context, p.session.ID, text)
+ a.Generate(context.Background(), p.session.ID, text)
}()
return tea.Batch(cmds...)
@@ -1,6 +1,8 @@
package tui
import (
+ "context"
+
"github.com/charmbracelet/bubbles/key"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
@@ -184,7 +186,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
case key.Matches(msg, replKeyMap):
if a.currentPage == page.ReplPage {
- sessions, err := a.app.Sessions.List()
+ sessions, err := a.app.Sessions.List(context.Background())
if err != nil {
return a, util.CmdHandler(util.ReportError(err))
}
@@ -192,7 +194,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if lastSession.MessageCount == 0 {
return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: lastSession.ID})
}
- s, err := a.app.Sessions.Create("New Session")
+ s, err := a.app.Sessions.Create(context.Background(), "New Session")
if err != nil {
return a, util.CmdHandler(util.ReportError(err))
}