From 4e4ade43b59eed979e5e35e6e0071e2cc9b0bbd3 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Mon, 13 Apr 2026 22:53:11 -0400 Subject: [PATCH] feat: remember provider/model info per-session When switching sessions use the last set provider and model if they exist. Thinking and reasoning settings are also preserved. This applies to both large and small models and targets both standalone and server-client implementations. This functionality can extend to additional agents when the time comes. --- internal/agent/agent.go | 10 +- internal/agent/coordinator.go | 4 +- internal/agent/tools/todos.go | 5 +- internal/app/resolve_session_test.go | 9 + internal/backend/session.go | 10 ++ internal/client/proto.go | 17 ++ internal/config/config.go | 21 +++ internal/config/selected_model_test.go | 143 +++++++++++++++ internal/db/db.go | 16 +- .../20260401000000_add_models_to_sessions.sql | 9 + internal/db/models.go | 1 + internal/db/querier.go | 1 + internal/db/sessions.sql.go | 47 ++++- internal/db/sql/sessions.sql | 6 + internal/proto/proto.go | 2 +- internal/proto/session.go | 44 +++-- internal/proto/session_test.go | 168 ++++++++++++++++++ internal/server/events.go | 25 +++ internal/server/proto.go | 34 ++++ internal/server/server.go | 1 + internal/session/session.go | 54 ++++++ internal/session/session_test.go | 151 ++++++++++++++++ internal/ui/model/ui.go | 48 +++++ internal/workspace/app_workspace.go | 4 + internal/workspace/client_workspace.go | 52 ++++++ internal/workspace/workspace.go | 1 + 26 files changed, 858 insertions(+), 25 deletions(-) create mode 100644 internal/config/selected_model_test.go create mode 100644 internal/db/migrations/20260401000000_add_models_to_sessions.sql create mode 100644 internal/proto/session_test.go create mode 100644 internal/session/session_test.go diff --git a/internal/agent/agent.go b/internal/agent/agent.go index f25cbdc7849c9f9f3d55e34206faaca82834960c..53c17f8b28d0d4de6be03873acb797c598ff8ed0 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -423,7 +423,10 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy return getSessionErr } a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata)) - _, sessionErr := a.sessions.Save(ctx, updatedSession) + _, sessionErr := a.sessions.SaveWithModels(ctx, updatedSession, map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: a.largeModel.Get().ModelCfg, + config.SelectedModelTypeSmall: a.smallModel.Get().ModelCfg, + }) if sessionErr != nil { return sessionErr } @@ -723,7 +726,10 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan currentSession.SummaryMessageID = summaryMessage.ID currentSession.CompletionTokens = usage.OutputTokens currentSession.PromptTokens = 0 - _, err = a.sessions.Save(genCtx, currentSession) + _, err = a.sessions.SaveWithModels(genCtx, currentSession, map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: a.largeModel.Get().ModelCfg, + config.SelectedModelTypeSmall: a.smallModel.Get().ModelCfg, + }) return err } diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 1130438a61217b3491bd21e388376f303631e9ec..60f65c8d4514f1a69e99a8466b920d7021d164af 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -476,7 +476,7 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep), tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls), tools.NewSourcegraphTool(nil), - tools.NewTodosTool(c.sessions), + tools.NewTodosTool(c.sessions, c.cfg), tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.skillTracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...), tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()), ) @@ -1051,7 +1051,7 @@ func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionI parentSession.Cost += childSession.Cost - if _, err := c.sessions.Save(ctx, parentSession); err != nil { + if _, err := c.sessions.SaveWithModels(ctx, parentSession, c.cfg.Config().Models); err != nil { return fmt.Errorf("save parent session: %w", err) } diff --git a/internal/agent/tools/todos.go b/internal/agent/tools/todos.go index 2f69f7bf84581d9f0e0776d73660bef7ba34ba43..ea6dca52d54814a6e0e83b71f6c7d77b8c5502db 100644 --- a/internal/agent/tools/todos.go +++ b/internal/agent/tools/todos.go @@ -6,6 +6,7 @@ import ( "fmt" "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/session" ) @@ -33,7 +34,7 @@ type TodosResponseMetadata struct { Total int `json:"total"` } -func NewTodosTool(sessions session.Service) fantasy.AgentTool { +func NewTodosTool(sessions session.Service, cfg *config.ConfigStore) fantasy.AgentTool { return fantasy.NewAgentTool( TodosToolName, FirstLineDescription(todosDescription), @@ -96,7 +97,7 @@ func NewTodosTool(sessions session.Service) fantasy.AgentTool { } currentSession.Todos = todos - _, err = sessions.Save(ctx, currentSession) + _, err = sessions.SaveWithModels(ctx, currentSession, cfg.Config().Models) if err != nil { return fantasy.ToolResponse{}, fmt.Errorf("failed to save todos: %w", err) } diff --git a/internal/app/resolve_session_test.go b/internal/app/resolve_session_test.go index 9b0c7af736fa9637c095c7851da3460bacf737a2..7bc79063b71401c4b4886eb981a58aceb75ac921 100644 --- a/internal/app/resolve_session_test.go +++ b/internal/app/resolve_session_test.go @@ -7,6 +7,7 @@ import ( "strings" "testing" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/session" "github.com/stretchr/testify/require" @@ -60,6 +61,14 @@ func (m *mockSessionService) Save(_ context.Context, s session.Session) (session return s, nil } +func (m *mockSessionService) SaveWithModels(_ context.Context, s session.Session, _ map[config.SelectedModelType]config.SelectedModel) (session.Session, error) { + return s, nil +} + +func (m *mockSessionService) UpdateSessionModels(context.Context, string, map[config.SelectedModelType]config.SelectedModel) error { + return nil +} + func (m *mockSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error { return nil } diff --git a/internal/backend/session.go b/internal/backend/session.go index 10e21ed8932ccbc990a525785166517cd231595c..6465d9eedfeaf20b52a9460f8ac25a4aae91ed54 100644 --- a/internal/backend/session.go +++ b/internal/backend/session.go @@ -3,6 +3,7 @@ package backend import ( "context" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/session" @@ -105,6 +106,15 @@ func (b *Backend) DeleteSession(ctx context.Context, workspaceID, sessionID stri return ws.Sessions.Delete(ctx, sessionID) } +// UpdateSessionModels updates the models for a session in the given workspace. +func (b *Backend) UpdateSessionModels(ctx context.Context, workspaceID, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + return ws.Sessions.UpdateSessionModels(ctx, sessionID, models) +} + // ListUserMessages returns user-role messages for a session. func (b *Backend) ListUserMessages(ctx context.Context, workspaceID, sessionID string) ([]message.Message, error) { ws, err := b.GetWorkspace(workspaceID) diff --git a/internal/client/proto.go b/internal/client/proto.go index f444cfc04f1e185a4551d6ac43eae4d99f3a02ba..ab125d7b1091ea68fb250aa1389c2d881706904e 100644 --- a/internal/client/proto.go +++ b/internal/client/proto.go @@ -585,6 +585,23 @@ func (c *Client) DeleteSession(ctx context.Context, id string, sessionID string) return nil } +// UpdateSessionModels updates the models for a session. +func (c *Client) UpdateSessionModels(ctx context.Context, workspaceID, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/models", workspaceID, sessionID), nil, + jsonBody(struct { + Models map[config.SelectedModelType]config.SelectedModel `json:"models"` + }{Models: models}), + http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to update session models: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusNoContent { + return fmt.Errorf("failed to update session models: status code %d", rsp.StatusCode) + } + return nil +} + // ListUserMessages retrieves user-role messages for a session as proto types. func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID string) ([]proto.Message, error) { rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/messages/user", id, sessionID), nil, nil) diff --git a/internal/config/config.go b/internal/config/config.go index cee8ab8c4964bce56aa7c3ddffe98af115498776..e010a81b91ebb81ebb139372f1b8e3394cb72427 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -89,6 +89,27 @@ type SelectedModel struct { ProviderOptions map[string]any `json:"provider_options,omitempty" jsonschema:"description=Additional provider-specific options for the model"` } +func (a SelectedModel) Equal(b SelectedModel) bool { + return a.Model == b.Model && + a.Provider == b.Provider && + a.ReasoningEffort == b.ReasoningEffort && + a.Think == b.Think && + a.MaxTokens == b.MaxTokens && + ptrEqual(a.Temperature, b.Temperature) && + ptrEqual(a.TopP, b.TopP) && + ptrEqual(a.TopK, b.TopK) && + ptrEqual(a.FrequencyPenalty, b.FrequencyPenalty) && + ptrEqual(a.PresencePenalty, b.PresencePenalty) && + maps.Equal(a.ProviderOptions, b.ProviderOptions) +} + +func ptrEqual[T comparable](a, b *T) bool { + if a == nil || b == nil { + return a == b + } + return *a == *b +} + type ProviderConfig struct { // The provider's id. ID string `json:"id,omitempty" jsonschema:"description=Unique identifier for the provider,example=openai"` diff --git a/internal/config/selected_model_test.go b/internal/config/selected_model_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1a96b92e876043caef8dfe101ddcf8cab347ecd8 --- /dev/null +++ b/internal/config/selected_model_test.go @@ -0,0 +1,143 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSelectedModelEqual(t *testing.T) { + t.Parallel() + + t.Run("equal models", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai"} + b := SelectedModel{Model: "gpt-4o", Provider: "openai"} + require.True(t, a.Equal(b)) + }) + + t.Run("different model", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai"} + b := SelectedModel{Model: "gpt-4o-mini", Provider: "openai"} + require.False(t, a.Equal(b)) + }) + + t.Run("different provider", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai"} + b := SelectedModel{Model: "gpt-4o", Provider: "anthropic"} + require.False(t, a.Equal(b)) + }) + + t.Run("different reasoning effort", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "o3", Provider: "openai", ReasoningEffort: "high"} + b := SelectedModel{Model: "o3", Provider: "openai", ReasoningEffort: "low"} + require.False(t, a.Equal(b)) + }) + + t.Run("different think", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "claude-sonnet", Provider: "anthropic", Think: true} + b := SelectedModel{Model: "claude-sonnet", Provider: "anthropic", Think: false} + require.False(t, a.Equal(b)) + }) + + t.Run("different max tokens", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai", MaxTokens: 4096} + b := SelectedModel{Model: "gpt-4o", Provider: "openai", MaxTokens: 8192} + require.False(t, a.Equal(b)) + }) + + t.Run("both nil pointers", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai"} + b := SelectedModel{Model: "gpt-4o", Provider: "openai"} + require.True(t, a.Equal(b)) + }) + + t.Run("one nil one non-nil pointer", func(t *testing.T) { + t.Parallel() + temp := 0.7 + a := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &temp} + b := SelectedModel{Model: "gpt-4o", Provider: "openai"} + require.False(t, a.Equal(b)) + require.False(t, b.Equal(a)) + }) + + t.Run("both non-nil equal pointers", func(t *testing.T) { + t.Parallel() + temp := 0.7 + a := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &temp} + b := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &temp} + require.True(t, a.Equal(b)) + }) + + t.Run("both non-nil different pointers", func(t *testing.T) { + t.Parallel() + t1 := 0.7 + t2 := 0.9 + a := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &t1} + b := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &t2} + require.False(t, a.Equal(b)) + }) + + t.Run("nil ProviderOptions", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai"} + b := SelectedModel{Model: "gpt-4o", Provider: "openai"} + require.True(t, a.Equal(b)) + }) + + t.Run("empty ProviderOptions", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{}} + b := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{}} + require.True(t, a.Equal(b)) + }) + + t.Run("different ProviderOptions", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{"key": "a"}} + b := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{"key": "b"}} + require.False(t, a.Equal(b)) + }) + + t.Run("one nil one non-nil ProviderOptions", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{"key": "a"}} + b := SelectedModel{Model: "gpt-4o", Provider: "openai"} + require.False(t, a.Equal(b)) + }) +} + +func TestPtrEqual(t *testing.T) { + t.Parallel() + + t.Run("both nil", func(t *testing.T) { + t.Parallel() + require.True(t, ptrEqual[int](nil, nil)) + }) + + t.Run("one nil", func(t *testing.T) { + t.Parallel() + v := 42 + require.False(t, ptrEqual(&v, nil)) + require.False(t, ptrEqual(nil, &v)) + }) + + t.Run("both equal", func(t *testing.T) { + t.Parallel() + v := 42 + require.True(t, ptrEqual(&v, &v)) + }) + + t.Run("both different", func(t *testing.T) { + t.Parallel() + a := 42 + b := 43 + require.False(t, ptrEqual(&a, &b)) + }) +} diff --git a/internal/db/db.go b/internal/db/db.go index 6237bba5892b240ddd3c3f018926e2eb0ef4355a..45e84f6e2246109ee1869a27b473db7b193c3f74 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -11,10 +11,10 @@ import ( ) type DBTX interface { - ExecContext(context.Context, string, ...any) (sql.Result, error) + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) PrepareContext(context.Context, string) (*sql.Stmt, error) - QueryContext(context.Context, string, ...any) (*sql.Rows, error) - QueryRowContext(context.Context, string, ...any) *sql.Row + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row } func New(db DBTX) *Queries { @@ -132,6 +132,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.updateSessionStmt, err = db.PrepareContext(ctx, updateSession); err != nil { return nil, fmt.Errorf("error preparing query UpdateSession: %w", err) } + if q.updateSessionModelsStmt, err = db.PrepareContext(ctx, updateSessionModels); err != nil { + return nil, fmt.Errorf("error preparing query UpdateSessionModels: %w", err) + } if q.updateSessionTitleAndUsageStmt, err = db.PrepareContext(ctx, updateSessionTitleAndUsage); err != nil { return nil, fmt.Errorf("error preparing query UpdateSessionTitleAndUsage: %w", err) } @@ -320,6 +323,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing updateSessionStmt: %w", cerr) } } + if q.updateSessionModelsStmt != nil { + if cerr := q.updateSessionModelsStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing updateSessionModelsStmt: %w", cerr) + } + } if q.updateSessionTitleAndUsageStmt != nil { if cerr := q.updateSessionTitleAndUsageStmt.Close(); cerr != nil { err = fmt.Errorf("error closing updateSessionTitleAndUsageStmt: %w", cerr) @@ -400,6 +408,7 @@ type Queries struct { renameSessionStmt *sql.Stmt updateMessageStmt *sql.Stmt updateSessionStmt *sql.Stmt + updateSessionModelsStmt *sql.Stmt updateSessionTitleAndUsageStmt *sql.Stmt } @@ -443,6 +452,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { renameSessionStmt: q.renameSessionStmt, updateMessageStmt: q.updateMessageStmt, updateSessionStmt: q.updateSessionStmt, + updateSessionModelsStmt: q.updateSessionModelsStmt, updateSessionTitleAndUsageStmt: q.updateSessionTitleAndUsageStmt, } } diff --git a/internal/db/migrations/20260401000000_add_models_to_sessions.sql b/internal/db/migrations/20260401000000_add_models_to_sessions.sql new file mode 100644 index 0000000000000000000000000000000000000000..639989eae3f6e4ef04791ab6aad422fbcc84b77d --- /dev/null +++ b/internal/db/migrations/20260401000000_add_models_to_sessions.sql @@ -0,0 +1,9 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE sessions ADD COLUMN models TEXT; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE sessions DROP COLUMN models; +-- +goose StatementEnd \ No newline at end of file diff --git a/internal/db/models.go b/internal/db/models.go index 20034fb00a935bed7c4cfe4906dba66dd380ed64..c9220a4ab714cd2ad4a5d1d23688bcb5981b1150 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -49,4 +49,5 @@ type Session struct { CreatedAt int64 `json:"created_at"` SummaryMessageID sql.NullString `json:"summary_message_id"` Todos sql.NullString `json:"todos"` + Models sql.NullString `json:"models"` } diff --git a/internal/db/querier.go b/internal/db/querier.go index 9031505a3db825f2c21d83e005046323bde3a6c2..3886d6388c0123636e560e8d4a9b977993f5be39 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -45,6 +45,7 @@ type Querier interface { RenameSession(ctx context.Context, arg RenameSessionParams) error UpdateMessage(ctx context.Context, arg UpdateMessageParams) error UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) + UpdateSessionModels(ctx context.Context, arg UpdateSessionModelsParams) (Session, error) UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error } diff --git a/internal/db/sessions.sql.go b/internal/db/sessions.sql.go index 685948e60e84ec4df66e4d5d1c9645a9ff1fb43f..4bf7ee8febbfd611f28f1b4d32e5c9f1690fd4e3 100644 --- a/internal/db/sessions.sql.go +++ b/internal/db/sessions.sql.go @@ -33,7 +33,7 @@ INSERT INTO sessions ( null, strftime('%s', 'now'), strftime('%s', 'now') -) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos +) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models ` type CreateSessionParams struct { @@ -69,6 +69,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S &i.CreatedAt, &i.SummaryMessageID, &i.Todos, + &i.Models, ) return i, err } @@ -84,7 +85,7 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error { } const getLastSession = `-- name: GetLastSession :one -SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos +SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models FROM sessions ORDER BY updated_at DESC LIMIT 1 @@ -105,12 +106,13 @@ func (q *Queries) GetLastSession(ctx context.Context) (Session, error) { &i.CreatedAt, &i.SummaryMessageID, &i.Todos, + &i.Models, ) return i, err } const getSessionByID = `-- name: GetSessionByID :one -SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos +SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models FROM sessions WHERE id = ? LIMIT 1 ` @@ -130,12 +132,13 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error &i.CreatedAt, &i.SummaryMessageID, &i.Todos, + &i.Models, ) return i, err } const listSessions = `-- name: ListSessions :many -SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos +SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models FROM sessions WHERE parent_session_id is NULL ORDER BY updated_at DESC @@ -162,6 +165,7 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) { &i.CreatedAt, &i.SummaryMessageID, &i.Todos, + &i.Models, ); err != nil { return nil, err } @@ -203,7 +207,7 @@ SET cost = ?, todos = ? WHERE id = ? -RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos +RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models ` type UpdateSessionParams struct { @@ -239,6 +243,39 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S &i.CreatedAt, &i.SummaryMessageID, &i.Todos, + &i.Models, + ) + return i, err +} + +const updateSessionModels = `-- name: UpdateSessionModels :one +UPDATE sessions +SET models = ? +WHERE id = ? +RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models +` + +type UpdateSessionModelsParams struct { + Models sql.NullString `json:"models"` + ID string `json:"id"` +} + +func (q *Queries) UpdateSessionModels(ctx context.Context, arg UpdateSessionModelsParams) (Session, error) { + row := q.queryRow(ctx, q.updateSessionModelsStmt, updateSessionModels, arg.Models, arg.ID) + var i Session + err := row.Scan( + &i.ID, + &i.ParentSessionID, + &i.Title, + &i.MessageCount, + &i.PromptTokens, + &i.CompletionTokens, + &i.Cost, + &i.UpdatedAt, + &i.CreatedAt, + &i.SummaryMessageID, + &i.Todos, + &i.Models, ) return i, err } diff --git a/internal/db/sql/sessions.sql b/internal/db/sql/sessions.sql index 44c1609ecfbc3867bea827088fcbcff6e718427b..75ab5fa5ae7efb50f1e32328fedd46adce77a9cd 100644 --- a/internal/db/sql/sessions.sql +++ b/internal/db/sql/sessions.sql @@ -69,6 +69,12 @@ SET title = ? WHERE id = ?; +-- name: UpdateSessionModels :one +UPDATE sessions +SET models = ? +WHERE id = ? +RETURNING *; + -- name: DeleteSession :exec DELETE FROM sessions WHERE id = ?; diff --git a/internal/proto/proto.go b/internal/proto/proto.go index 9c84c6c8bf0c2f14da75933f8eebd7a36ca534ba..6d7ee4b150a132013a80c2070588e20ef8bdf456 100644 --- a/internal/proto/proto.go +++ b/internal/proto/proto.go @@ -56,7 +56,7 @@ type AgentSession struct { // IsZero checks if the AgentSession is zero-valued. func (a AgentSession) IsZero() bool { - return a == AgentSession{} + return a.ID == "" && !a.IsBusy } // PermissionAction represents an action taken on a permission request. diff --git a/internal/proto/session.go b/internal/proto/session.go index 846ac592017e6ce447c6c6a94535d9317adad7d8..fdf31b6c7746998db5c4c24a7820551357a5c33c 100644 --- a/internal/proto/session.go +++ b/internal/proto/session.go @@ -1,15 +1,39 @@ package proto +// SelectedModelType represents the type of model selection (large or small). +type SelectedModelType string + +const ( + SelectedModelTypeLarge SelectedModelType = "large" + SelectedModelTypeSmall SelectedModelType = "small" +) + +// SelectedModel represents a model selection with provider and configuration. +type SelectedModel struct { + Model string `json:"model"` + Provider string `json:"provider"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` + Think bool `json:"think,omitempty"` + MaxTokens int64 `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK *int64 `json:"top_k,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + ProviderOptions map[string]any `json:"provider_options,omitempty"` +} + // Session represents a session in the proto layer. type Session struct { - ID string `json:"id"` - ParentSessionID string `json:"parent_session_id"` - Title string `json:"title"` - MessageCount int64 `json:"message_count"` - PromptTokens int64 `json:"prompt_tokens"` - CompletionTokens int64 `json:"completion_tokens"` - SummaryMessageID string `json:"summary_message_id"` - Cost float64 `json:"cost"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` + ID string `json:"id"` + ParentSessionID string `json:"parent_session_id"` + Title string `json:"title"` + MessageCount int64 `json:"message_count"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + SummaryMessageID string `json:"summary_message_id"` + Cost float64 `json:"cost"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + Models map[SelectedModelType]SelectedModel `json:"models,omitempty"` } diff --git a/internal/proto/session_test.go b/internal/proto/session_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7634be230f744e11836ed1e85545256f09c31bac --- /dev/null +++ b/internal/proto/session_test.go @@ -0,0 +1,168 @@ +package proto + +import ( + "encoding/json" + "testing" + + "github.com/charmbracelet/crush/internal/config" + "github.com/stretchr/testify/require" +) + +func TestModelsRoundTrip(t *testing.T) { + t.Parallel() + + t.Run("nil models", func(t *testing.T) { + t.Parallel() + s := Session{ + ID: "test-id", + Title: "Test", + } + data, err := json.Marshal(s) + require.NoError(t, err) + var decoded Session + require.NoError(t, json.Unmarshal(data, &decoded)) + require.Nil(t, decoded.Models) + }) + + t.Run("populated models", func(t *testing.T) { + t.Parallel() + temp := 0.7 + s := Session{ + ID: "test-id", + Title: "Test", + Models: map[SelectedModelType]SelectedModel{ + SelectedModelTypeLarge: { + Model: "gpt-4o", + Provider: "openai", + ReasoningEffort: "high", + Think: true, + MaxTokens: 4096, + Temperature: &temp, + ProviderOptions: map[string]any{"key": "value"}, + }, + SelectedModelTypeSmall: { + Model: "gpt-4o-mini", + Provider: "openai", + }, + }, + } + data, err := json.Marshal(s) + require.NoError(t, err) + var decoded Session + require.NoError(t, json.Unmarshal(data, &decoded)) + require.Equal(t, "gpt-4o", decoded.Models[SelectedModelTypeLarge].Model) + require.Equal(t, "openai", decoded.Models[SelectedModelTypeLarge].Provider) + require.Equal(t, "high", decoded.Models[SelectedModelTypeLarge].ReasoningEffort) + require.True(t, decoded.Models[SelectedModelTypeLarge].Think) + require.Equal(t, int64(4096), decoded.Models[SelectedModelTypeLarge].MaxTokens) + require.NotNil(t, decoded.Models[SelectedModelTypeLarge].Temperature) + require.Equal(t, 0.7, *decoded.Models[SelectedModelTypeLarge].Temperature) + require.Equal(t, "gpt-4o-mini", decoded.Models[SelectedModelTypeSmall].Model) + }) + + t.Run("empty map models", func(t *testing.T) { + t.Parallel() + s := Session{ + ID: "test-id", + Title: "Test", + Models: map[SelectedModelType]SelectedModel{}, + } + data, err := json.Marshal(s) + require.NoError(t, err) + var decoded Session + require.NoError(t, json.Unmarshal(data, &decoded)) + // Empty map with omitempty is dropped during marshaling. + require.Nil(t, decoded.Models) + }) +} + +func TestProtoToDomainRoundTrip(t *testing.T) { + t.Parallel() + + t.Run("models through proto", func(t *testing.T) { + t.Parallel() + temp := 0.7 + domainModels := map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: { + Model: "gpt-4o", + Provider: "openai", + ReasoningEffort: "high", + Think: true, + MaxTokens: 4096, + Temperature: &temp, + ProviderOptions: map[string]any{"key": "value"}, + }, + } + + // Domain → Proto + protoModels := convertModelsToProtoLocal(domainModels) + require.Equal(t, SelectedModelTypeLarge, SelectedModelType(config.SelectedModelTypeLarge)) + require.Equal(t, "gpt-4o", protoModels[SelectedModelTypeLarge].Model) + require.Equal(t, "openai", protoModels[SelectedModelTypeLarge].Provider) + + // Proto → Domain + result := convertModelsFromProtoLocal(protoModels) + require.Equal(t, "gpt-4o", result[config.SelectedModelTypeLarge].Model) + require.Equal(t, "openai", result[config.SelectedModelTypeLarge].Provider) + require.Equal(t, "high", result[config.SelectedModelTypeLarge].ReasoningEffort) + require.True(t, result[config.SelectedModelTypeLarge].Think) + require.Equal(t, int64(4096), result[config.SelectedModelTypeLarge].MaxTokens) + require.NotNil(t, result[config.SelectedModelTypeLarge].Temperature) + require.Equal(t, 0.7, *result[config.SelectedModelTypeLarge].Temperature) + }) + + t.Run("nil models round-trip", func(t *testing.T) { + t.Parallel() + protoModels := convertModelsToProtoLocal(nil) + require.Nil(t, protoModels) + + domainModels := convertModelsFromProtoLocal(nil) + require.Nil(t, domainModels) + }) +} + +func convertModelsToProtoLocal(models map[config.SelectedModelType]config.SelectedModel) map[SelectedModelType]SelectedModel { + if models == nil { + return nil + } + result := make(map[SelectedModelType]SelectedModel, len(models)) + for k, v := range models { + result[SelectedModelType(k)] = SelectedModel{ + Model: v.Model, + Provider: v.Provider, + ReasoningEffort: v.ReasoningEffort, + Think: v.Think, + MaxTokens: v.MaxTokens, + Temperature: v.Temperature, + TopP: v.TopP, + TopK: v.TopK, + FrequencyPenalty: v.FrequencyPenalty, + PresencePenalty: v.PresencePenalty, + ProviderOptions: v.ProviderOptions, + } + } + return result +} + +func convertModelsFromProtoLocal(models map[SelectedModelType]SelectedModel) map[config.SelectedModelType]config.SelectedModel { + if models == nil { + return nil + } + result := make(map[config.SelectedModelType]config.SelectedModel, len(models)) + for k, v := range models { + result[config.SelectedModelType(k)] = config.SelectedModel{ + Model: v.Model, + Provider: v.Provider, + ReasoningEffort: v.ReasoningEffort, + Think: v.Think, + MaxTokens: v.MaxTokens, + Temperature: v.Temperature, + TopP: v.TopP, + TopK: v.TopK, + FrequencyPenalty: v.FrequencyPenalty, + PresencePenalty: v.PresencePenalty, + ProviderOptions: v.ProviderOptions, + } + } + return result +} diff --git a/internal/server/events.go b/internal/server/events.go index 752311666bb6fcc2b1efde4d037711eaafaa0162..d3b1341a1e8166c2eb70d804f100a5b2fd08e614 100644 --- a/internal/server/events.go +++ b/internal/server/events.go @@ -8,6 +8,7 @@ import ( "github.com/charmbracelet/crush/internal/agent/notify" "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/permission" @@ -137,9 +138,33 @@ func sessionToProto(s session.Session) proto.Session { Cost: s.Cost, CreatedAt: s.CreatedAt, UpdatedAt: s.UpdatedAt, + Models: convertModelsToProto(s.Models), } } +func convertModelsToProto(models map[config.SelectedModelType]config.SelectedModel) map[proto.SelectedModelType]proto.SelectedModel { + if models == nil { + return nil + } + result := make(map[proto.SelectedModelType]proto.SelectedModel, len(models)) + for k, v := range models { + result[proto.SelectedModelType(k)] = proto.SelectedModel{ + Model: v.Model, + Provider: v.Provider, + ReasoningEffort: v.ReasoningEffort, + Think: v.Think, + MaxTokens: v.MaxTokens, + Temperature: v.Temperature, + TopP: v.TopP, + TopK: v.TopK, + FrequencyPenalty: v.FrequencyPenalty, + PresencePenalty: v.PresencePenalty, + ProviderOptions: v.ProviderOptions, + } + } + return result +} + func fileToProto(f history.File) proto.File { return proto.File{ ID: f.ID, diff --git a/internal/server/proto.go b/internal/server/proto.go index af34131810a7af8c3672fe460198d25afe9ba064..2e9f0ddef2f7b3c072b751ba689be8d6e4517655 100644 --- a/internal/server/proto.go +++ b/internal/server/proto.go @@ -7,6 +7,7 @@ import ( "net/http" "github.com/charmbracelet/crush/internal/backend" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/session" ) @@ -460,6 +461,39 @@ func (c *controllerV1) handleDeleteWorkspaceSession(w http.ResponseWriter, r *ht w.WriteHeader(http.StatusOK) } +// handlePostWorkspaceSessionModels updates the models for a session. +// +// @Summary Update session models +// @Tags sessions +// @Accept json +// @Param id path string true "Workspace ID" +// @Param sid path string true "Session ID" +// @Param body body map[config.SelectedModelType]config.SelectedModel true "Models" +// @Success 204 +// @Failure 400 {object} proto.Error +// @Failure 500 {object} proto.Error +// @Router /workspaces/{id}/sessions/{sid}/models [post] +func (c *controllerV1) handlePostWorkspaceSessionModels(w http.ResponseWriter, r *http.Request) { + workspaceID := r.PathValue("id") + sessionID := r.PathValue("sid") + + var req struct { + Models map[config.SelectedModelType]config.SelectedModel `json:"models"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if err := c.backend.UpdateSessionModels(r.Context(), workspaceID, sessionID, req.Models); err != nil { + c.handleError(w, r, err) + return + } + + w.WriteHeader(http.StatusNoContent) +} + // handleGetWorkspaceSessionUserMessages returns user messages for a session. // // @Summary Get user messages for session diff --git a/internal/server/server.go b/internal/server/server.go index 9ac4dba4c908050a0381b49258941d1b3a931970..275c79ec031a50876d7c72e8e2868b1f6c378950 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -122,6 +122,7 @@ func NewServer(cfg *config.ConfigStore, network, address string) *Server { mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}", c.handleGetWorkspaceSession) mux.HandleFunc("PUT /v1/workspaces/{id}/sessions/{sid}", c.handlePutWorkspaceSession) mux.HandleFunc("DELETE /v1/workspaces/{id}/sessions/{sid}", c.handleDeleteWorkspaceSession) + mux.HandleFunc("POST /v1/workspaces/{id}/sessions/{sid}/models", c.handlePostWorkspaceSessionModels) mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/history", c.handleGetWorkspaceSessionHistory) mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages", c.handleGetWorkspaceSessionMessages) mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages/user", c.handleGetWorkspaceSessionUserMessages) diff --git a/internal/session/session.go b/internal/session/session.go index 66bd9f4c9a12916d02c6d22ed7d51f81d74efdfd..0271bf9a293095c259f114bce2de5e0d901e5ec1 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -8,6 +8,7 @@ import ( "log/slog" "strings" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/event" "github.com/charmbracelet/crush/internal/pubsub" @@ -56,6 +57,7 @@ type Session struct { SummaryMessageID string Cost float64 Todos []Todo + Models map[config.SelectedModelType]config.SelectedModel CreatedAt int64 UpdatedAt int64 } @@ -69,6 +71,8 @@ type Service interface { GetLast(ctx context.Context) (Session, error) List(ctx context.Context) ([]Session, error) Save(ctx context.Context, session Session) (Session, error) + SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error) + UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error Rename(ctx context.Context, id string, title string) error Delete(ctx context.Context, id string) error @@ -242,6 +246,10 @@ func (s service) fromDBItem(item db.Session) Session { if err != nil { slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err) } + models, err := unmarshalModels(item.Models.String) + if err != nil { + slog.Error("Failed to unmarshal models", "session_id", item.ID, "error", err) + } return Session{ ID: item.ID, ParentSessionID: item.ParentSessionID.String, @@ -252,6 +260,7 @@ func (s service) fromDBItem(item db.Session) Session { SummaryMessageID: item.SummaryMessageID.String, Cost: item.Cost, Todos: todos, + Models: models, CreatedAt: item.CreatedAt, UpdatedAt: item.UpdatedAt, } @@ -279,6 +288,51 @@ func unmarshalTodos(data string) ([]Todo, error) { return todos, nil } +func marshalModels(models map[config.SelectedModelType]config.SelectedModel) (string, error) { + if len(models) == 0 { + return "", nil + } + data, err := json.Marshal(models) + if err != nil { + return "", err + } + return string(data), nil +} + +func unmarshalModels(data string) (map[config.SelectedModelType]config.SelectedModel, error) { + if data == "" { + return nil, nil + } + var models map[config.SelectedModelType]config.SelectedModel + if err := json.Unmarshal([]byte(data), &models); err != nil { + return nil, err + } + return models, nil +} + +func (s *service) UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error { + modelsJSON, err := marshalModels(models) + if err != nil { + return fmt.Errorf("failed to marshal models: %w", err) + } + _, err = s.q.UpdateSessionModels(ctx, db.UpdateSessionModelsParams{ + Models: sql.NullString{String: modelsJSON, Valid: modelsJSON != ""}, + ID: id, + }) + return err +} + +func (s *service) SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error) { + saved, err := s.Save(ctx, session) + if err != nil { + return Session{}, err + } + if err := s.UpdateSessionModels(ctx, session.ID, models); err != nil { + return Session{}, fmt.Errorf("failed to persist models: %w", err) + } + return saved, nil +} + func NewService(q *db.Queries, conn *sql.DB) Service { broker := pubsub.NewBroker[Session]() return &service{ diff --git a/internal/session/session_test.go b/internal/session/session_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c8f06795006081e7a05750fd6970b2d79076819f --- /dev/null +++ b/internal/session/session_test.go @@ -0,0 +1,151 @@ +package session + +import ( + "database/sql" + "testing" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/db" + "github.com/stretchr/testify/require" +) + +func TestMarshalModels(t *testing.T) { + t.Parallel() + + t.Run("empty", func(t *testing.T) { + t.Parallel() + result, err := marshalModels(map[config.SelectedModelType]config.SelectedModel{}) + require.NoError(t, err) + require.Equal(t, "", result) + }) + + t.Run("nil", func(t *testing.T) { + t.Parallel() + result, err := marshalModels(nil) + require.NoError(t, err) + require.Equal(t, "", result) + }) + + t.Run("single entry", func(t *testing.T) { + t.Parallel() + models := map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: { + Model: "claude-sonnet-4-20250514", + Provider: "anthropic", + }, + } + result, err := marshalModels(models) + require.NoError(t, err) + require.Contains(t, result, "claude-sonnet-4-20250514") + require.Contains(t, result, "anthropic") + }) + + t.Run("round-trip", func(t *testing.T) { + t.Parallel() + temp := 0.7 + topP := 0.9 + topK := int64(50) + freqPen := 0.1 + presPen := 0.2 + models := map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: { + Model: "gpt-4o", + Provider: "openai", + ReasoningEffort: "high", + Think: true, + MaxTokens: 4096, + Temperature: &temp, + TopP: &topP, + TopK: &topK, + FrequencyPenalty: &freqPen, + PresencePenalty: &presPen, + ProviderOptions: map[string]any{"key": "value"}, + }, + config.SelectedModelTypeSmall: { + Model: "gpt-4o-mini", + Provider: "openai", + }, + } + data, err := marshalModels(models) + require.NoError(t, err) + result, err := unmarshalModels(data) + require.NoError(t, err) + require.Equal(t, models, result) + }) +} + +func TestUnmarshalModels(t *testing.T) { + t.Parallel() + + t.Run("empty string", func(t *testing.T) { + t.Parallel() + result, err := unmarshalModels("") + require.NoError(t, err) + require.Nil(t, result) + }) + + t.Run("valid JSON", func(t *testing.T) { + t.Parallel() + data := `{"large":{"model":"gpt-4o","provider":"openai"}}` + result, err := unmarshalModels(data) + require.NoError(t, err) + require.Equal(t, "gpt-4o", result[config.SelectedModelTypeLarge].Model) + require.Equal(t, "openai", result[config.SelectedModelTypeLarge].Provider) + }) + + t.Run("invalid JSON", func(t *testing.T) { + t.Parallel() + _, err := unmarshalModels("{invalid}") + require.Error(t, err) + }) +} + +func TestFromDBItemWithModels(t *testing.T) { + t.Parallel() + + t.Run("null models", func(t *testing.T) { + t.Parallel() + item := testDBSession() + item.Models = sql.NullString{Valid: false} + result := service{}.fromDBItem(item) + require.Nil(t, result.Models) + }) + + t.Run("empty models", func(t *testing.T) { + t.Parallel() + item := testDBSession() + item.Models = sql.NullString{String: "", Valid: false} + result := service{}.fromDBItem(item) + require.Nil(t, result.Models) + }) + + t.Run("valid models", func(t *testing.T) { + t.Parallel() + item := testDBSession() + item.Models = sql.NullString{ + String: `{"large":{"model":"gpt-4o","provider":"openai"}}`, + Valid: true, + } + result := service{}.fromDBItem(item) + require.NotNil(t, result.Models) + require.Equal(t, "gpt-4o", result.Models[config.SelectedModelTypeLarge].Model) + }) + + t.Run("invalid JSON models", func(t *testing.T) { + t.Parallel() + item := testDBSession() + item.Models = sql.NullString{ + String: "{invalid}", + Valid: true, + } + result := service{}.fromDBItem(item) + require.Nil(t, result.Models) + }) +} + +func testDBSession() db.Session { + return db.Session{ + ID: "test-id", + Title: "Test", + } +} diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index e5724987898a755f06a2b2b4643cce9ff94a3d83..15e8ec8cdb7b5f084cdda9e2d860a6933fa45efb 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -498,6 +498,29 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.setState(uiChat, m.focus) m.session = msg.session m.sessionFiles = msg.files + if msg.session.Models != nil { + cfg := m.com.Config() + anyChanged := false + for modelType, selectedModel := range msg.session.Models { + if current, ok := cfg.Models[modelType]; ok && current.Equal(selectedModel) { + continue + } + if cfg.GetModel(selectedModel.Provider, selectedModel.Model) != nil { + if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, modelType, selectedModel); err != nil { + slog.Error("Failed to restore model", "type", modelType, "error", err) + } + anyChanged = true + } + } + if anyChanged { + cmds = append(cmds, func() tea.Msg { + if err := m.com.Workspace.UpdateAgentModel(context.TODO()); err != nil { + return util.ReportError(err) + } + return nil + }) + } + } cmds = append(cmds, m.startLSPs(msg.lspFilePaths())) msgs, err := m.com.Workspace.ListMessages(context.Background(), m.session.ID) if err != nil { @@ -1376,6 +1399,9 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { } return util.NewInfoMsg("Thinking mode " + status) }) + if m.session != nil { + cmds = append(cmds, m.saveSessionModels()) + } m.dialog.CloseDialog(dialog.CommandsID) case dialog.ActionToggleTransparentBackground: cmds = append(cmds, func() tea.Msg { @@ -1465,6 +1491,10 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { return util.NewInfoMsg(modelMsg) }) + if m.session != nil { + cmds = append(cmds, m.saveSessionModels()) + } + m.dialog.CloseDialog(dialog.APIKeyInputID) m.dialog.CloseDialog(dialog.OAuthID) m.dialog.CloseDialog(dialog.ModelsID) @@ -1505,6 +1535,9 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { m.com.Workspace.UpdateAgentModel(context.TODO()) return util.NewInfoMsg("Reasoning effort set to " + msg.Effort) }) + if m.session != nil { + cmds = append(cmds, m.saveSessionModels()) + } m.dialog.CloseDialog(dialog.ReasoningID) case dialog.ActionPermissionResponse: m.dialog.CloseDialog(dialog.PermissionsID) @@ -3652,3 +3685,18 @@ func renderLogo(t *styles.Styles, compact bool, width int) string { Width: width, }) } + +func (m *UI) saveSessionModels() tea.Cmd { + session := m.session + if session == nil { + return nil + } + models := m.com.Config().Models + return func() tea.Msg { + err := m.com.Workspace.UpdateSessionModels(context.Background(), session.ID, models) + if err != nil { + return util.ReportError(err) + } + return nil + } +} diff --git a/internal/workspace/app_workspace.go b/internal/workspace/app_workspace.go index 57b1228e7eacb28a16141283ee2703a33511bd18..7a5c7a01aaefa96195e07c3aaaebfacf9b0faa91 100644 --- a/internal/workspace/app_workspace.go +++ b/internal/workspace/app_workspace.go @@ -59,6 +59,10 @@ func (w *AppWorkspace) DeleteSession(ctx context.Context, sessionID string) erro return w.app.Sessions.Delete(ctx, sessionID) } +func (w *AppWorkspace) UpdateSessionModels(ctx context.Context, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error { + return w.app.Sessions.UpdateSessionModels(ctx, sessionID, models) +} + func (w *AppWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string { return w.app.Sessions.CreateAgentToolSessionID(messageID, toolCallID) } diff --git a/internal/workspace/client_workspace.go b/internal/workspace/client_workspace.go index 7c4e1408882cc70859ea2ab05981461d262513e9..1dbda7aa3ee5b9a1c73cee2681ecd383dd9e98ed 100644 --- a/internal/workspace/client_workspace.go +++ b/internal/workspace/client_workspace.go @@ -119,6 +119,10 @@ func (w *ClientWorkspace) DeleteSession(ctx context.Context, sessionID string) e return w.client.DeleteSession(ctx, w.workspaceID(), sessionID) } +func (w *ClientWorkspace) UpdateSessionModels(ctx context.Context, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error { + return w.client.UpdateSessionModels(ctx, w.workspaceID(), sessionID, models) +} + func (w *ClientWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string { return fmt.Sprintf("%s$$%s", messageID, toolCallID) } @@ -673,6 +677,7 @@ func protoToSession(s proto.Session) session.Session { Cost: s.Cost, CreatedAt: s.CreatedAt, UpdatedAt: s.UpdatedAt, + Models: convertModelsFromProto(s.Models), } } @@ -769,5 +774,52 @@ func sessionToProto(s session.Session) proto.Session { Cost: s.Cost, CreatedAt: s.CreatedAt, UpdatedAt: s.UpdatedAt, + Models: convertModelsToProtoClient(s.Models), + } +} + +func convertModelsFromProto(models map[proto.SelectedModelType]proto.SelectedModel) map[config.SelectedModelType]config.SelectedModel { + if models == nil { + return nil + } + result := make(map[config.SelectedModelType]config.SelectedModel, len(models)) + for k, v := range models { + result[config.SelectedModelType(k)] = config.SelectedModel{ + Model: v.Model, + Provider: v.Provider, + ReasoningEffort: v.ReasoningEffort, + Think: v.Think, + MaxTokens: v.MaxTokens, + Temperature: v.Temperature, + TopP: v.TopP, + TopK: v.TopK, + FrequencyPenalty: v.FrequencyPenalty, + PresencePenalty: v.PresencePenalty, + ProviderOptions: v.ProviderOptions, + } } + return result +} + +func convertModelsToProtoClient(models map[config.SelectedModelType]config.SelectedModel) map[proto.SelectedModelType]proto.SelectedModel { + if models == nil { + return nil + } + result := make(map[proto.SelectedModelType]proto.SelectedModel, len(models)) + for k, v := range models { + result[proto.SelectedModelType(k)] = proto.SelectedModel{ + Model: v.Model, + Provider: v.Provider, + ReasoningEffort: v.ReasoningEffort, + Think: v.Think, + MaxTokens: v.MaxTokens, + Temperature: v.Temperature, + TopP: v.TopP, + TopK: v.TopK, + FrequencyPenalty: v.FrequencyPenalty, + PresencePenalty: v.PresencePenalty, + ProviderOptions: v.ProviderOptions, + } + } + return result } diff --git a/internal/workspace/workspace.go b/internal/workspace/workspace.go index 02c54c616f3251140bbee441451c3a4cb14845bd..486edf01bd8fce6f975234a1840237aba4416929 100644 --- a/internal/workspace/workspace.go +++ b/internal/workspace/workspace.go @@ -65,6 +65,7 @@ type Workspace interface { ListSessions(ctx context.Context) ([]session.Session, error) SaveSession(ctx context.Context, sess session.Session) (session.Session, error) DeleteSession(ctx context.Context, sessionID string) error + UpdateSessionModels(ctx context.Context, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error CreateAgentToolSessionID(messageID, toolCallID string) string ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)