feat: Allow resuming previous AI sessions (#3407)
Codespell / Check for spelling errors (push) Has been cancelled
build-docker / publish (push) Has been cancelled
Install / install (depot-ubuntu-24.04) (push) Has been cancelled
Install / install (macos-14) (push) Has been cancelled
Nix / check (push) Has been cancelled
Nix / build-test (push) Has been cancelled
Rust / format (push) Has been cancelled
Rust / build (depot-ubuntu-24.04) (push) Has been cancelled
Rust / build (macos-14) (push) Has been cancelled
Rust / build (windows-latest) (push) Has been cancelled
Rust / cross-compile (x86_64-unknown-illumos) (push) Has been cancelled
Rust / unit-test (depot-ubuntu-24.04) (push) Has been cancelled
Rust / unit-test (macos-14) (push) Has been cancelled
Rust / unit-test (windows-latest) (push) Has been cancelled
Rust / check (depot-ubuntu-24.04) (push) Has been cancelled
Rust / check (macos-14) (push) Has been cancelled
Rust / check (windows-latest) (push) Has been cancelled
Rust / integration-test (push) Has been cancelled
Rust / clippy (push) Has been cancelled
Shellcheck / shellcheck (push) Has been cancelled

This PR introduces session continuation to Atuin AI.

* Conversations with Atuin AI are stored in a local SQLite database
* Upon startup, Atuin AI tries to find a session to resume based on its
directory/workspace and the time since the last event
* If found, Atuin AI will show a note that the session has been resumed,
and an event is added to help the LLM know where the invocation
boundaries are
* If not, Atuin AI will create a new conversation
* The user can create a new conversation with `/new`
* The new setting `ai.session_continue_minutes`, which defaults to `60`,
controls how old the last event in a session can be before it's no
longer considered for automatic resuming.

<img width="1055" height="593" alt="image"
src="https://github.com/user-attachments/assets/3f9ff01a-ef64-44a9-b0e2-3a4252c5746f"
/>

## Architecture

A new `SessionService` trait defines an API contract for a service that
can manage session data. `LocalSessionService` implements this, with
`DaemonSessionService` a possible future extension point.

`SessionManager` owns a `dyn SessionService` and delegates as
appropriate.
This commit is contained in:
Michelle Tilley
2026-04-14 16:03:08 -07:00
committed by GitHub
parent 06ae8775f2
commit fd188da879
23 changed files with 2624 additions and 163 deletions
+6 -4
View File
@@ -13,10 +13,12 @@ Before working on anything, we suggest taking a copy of your Atuin data director
While data directory backups are always a good idea, you can instruct Atuin to use custom path using the following environment variables:
```shell
export ATUIN_RECORD_STORE_PATH=/tmp/atuin_records.db
export ATUIN_DB_PATH=/tmp/atuin_dev.db
export ATUIN_KV__DB_PATH=/tmp/atuin_kv.db
export ATUIN_SCRIPTS__DB_PATH=/tmp/atuin_scripts.db
export ATUIN_RECORD_STORE_PATH=/tmp/atuin_records.db # path to primary record store
export ATUIN_DB_PATH=/tmp/atuin_dev.db # path to materialized history database
export ATUIN_KV__DB_PATH=/tmp/atuin_kv.db # path to key-value store
export ATUIN_SCRIPTS__DB_PATH=/tmp/atuin_scripts.db # path to scripts database
export ATUIN_AI__DB_PATH=/tmp/atuin_ai_sessions.db # path to AI sessions database
export ATUIN_META__DB_PATH=/tmp/atuin_meta.db # path to meta database
```
It is also recommended to update your `$PATH` so that the pre-exec scripts would use the locally built version:
Generated
+20 -4
View File
@@ -271,14 +271,18 @@ name = "atuin-ai"
version = "18.14.1"
dependencies = [
"async-stream",
"async-trait",
"atuin-client",
"atuin-common",
"chrono",
"chrono-humanize",
"clap",
"crossterm",
"directories",
"eventsource-stream",
"eye_declare",
"eyre",
"fs-err",
"futures",
"glob-match",
"pretty_assertions",
@@ -290,6 +294,7 @@ dependencies = [
"reqwest",
"serde",
"serde_json",
"sqlx",
"tempfile",
"thiserror 2.0.18",
"time",
@@ -823,11 +828,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0"
dependencies = [
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-link",
]
[[package]]
name = "chrono-humanize"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "799627e6b4d27827a814e837b9d8a504832086081806d45b1afa34dc982b023b"
dependencies = [
"chrono",
]
[[package]]
name = "cipher"
version = "0.4.4"
@@ -1516,9 +1532,9 @@ dependencies = [
[[package]]
name = "eye_declare"
version = "0.4.0"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9abe8051754adccf30ac4a0d54ce083a645fee7d4fc6c78d9d9770821bad45d"
checksum = "cd705fa26778c4cd8cd93f08b76986495601e5fc7039ff0f80499d0f1398ca62"
dependencies = [
"crossterm",
"eye_declare_macros",
@@ -1532,9 +1548,9 @@ dependencies = [
[[package]]
name = "eye_declare_macros"
version = "0.4.0"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39251ef16365f347032ab2344ad806f64d59f29fb171b4bafd05595fbda2604d"
checksum = "ae446305ea9f3f4679bd632a43e69eed48ba5484d5d692882a4c43e4666fe25d"
dependencies = [
"proc-macro2",
"quote",
+6 -1
View File
@@ -17,6 +17,7 @@ default = []
tree-sitter = ["dep:tree-sitter-lib", "dep:tree-sitter-bash", "dep:tree-sitter-fish"]
[dependencies]
async-trait = { workspace = true }
atuin-client = { workspace = true }
atuin-common = { workspace = true }
tokio = { workspace = true }
@@ -36,6 +37,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
crossterm = { workspace = true, features = ["use-dev-tty", "event-stream"] }
ratatui = { workspace = true }
fs-err = { workspace = true }
futures = "0.3"
eventsource-stream = "0.2"
pulldown-cmark = "0.13.0"
@@ -43,7 +45,7 @@ async-stream = "0.3"
uuid = { workspace = true }
tui-textarea-2 = "0.10.2"
unicode-width = "0.2"
eye_declare = "0.4"
eye_declare = "0.4.2"
ratatui-core = "0.1"
ratatui-widgets = "0.3"
thiserror = { workspace = true }
@@ -55,8 +57,11 @@ toml_edit = { workspace = true }
tree-sitter-lib = { package = "tree-sitter", version = "0.26.8", optional = true }
tree-sitter-bash = { version = "0.25.1", optional = true }
tree-sitter-fish = { version = "3.6.0", optional = true }
sqlx = { workspace = true, features = ["sqlite"] }
typed-builder = { workspace = true }
vt100 = { workspace = true }
chrono = "0.4"
chrono-humanize = "0.2"
[dev-dependencies]
pretty_assertions = { workspace = true }
@@ -0,0 +1,32 @@
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
head_id TEXT,
server_session_id TEXT,
directory TEXT,
git_root TEXT,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
archived_at INTEGER
);
CREATE INDEX idx_sessions_directory ON sessions(directory);
CREATE INDEX idx_sessions_git_root ON sessions(git_root);
CREATE INDEX idx_sessions_updated_at ON sessions(updated_at);
CREATE INDEX idx_sessions_created_at ON sessions(created_at);
CREATE TABLE IF NOT EXISTS session_events (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
parent_id TEXT,
invocation_id TEXT NOT NULL,
event_type TEXT NOT NULL,
event_data TEXT NOT NULL,
created_at INTEGER NOT NULL,
FOREIGN KEY (session_id) REFERENCES sessions(id)
);
CREATE INDEX idx_session_events_session_id ON session_events(session_id);
CREATE INDEX idx_session_events_parent_id ON session_events(parent_id);
CREATE INDEX idx_session_events_invocation_id ON session_events(invocation_id);
CREATE INDEX idx_session_events_created_at ON session_events(created_at);
+69 -5
View File
@@ -2,6 +2,7 @@ use std::path::PathBuf;
use std::sync::mpsc;
use crate::context::{AppContext, ClientContext};
use crate::session::{LocalSessionService, SessionManager, SessionService};
use crate::tui::dispatch;
use crate::tui::events::AiTuiEvent;
use crate::tui::state::{ExitAction, Session};
@@ -83,7 +84,7 @@ pub(crate) async fn run(
capabilities: settings.ai.capabilities.clone(),
};
let action = run_inline_tui(ctx, initial_command).await?;
let action = run_inline_tui(ctx, initial_command, settings).await?;
emit_shell_result(action, output_for_hook);
Ok(())
@@ -147,12 +148,74 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu
// ───────────────────────────────────────────────────────────────────
async fn run_inline_tui(ctx: AppContext, initial_prompt: Option<String>) -> Result<Action> {
async fn run_inline_tui(
ctx: AppContext,
initial_prompt: Option<String>,
settings: &atuin_client::settings::Settings,
) -> Result<Action> {
let client_ctx = ClientContext::detect();
let (tx, rx) = mpsc::channel::<AiTuiEvent>();
// Open the session service and check for a resumable session
let service = LocalSessionService::open(&settings.ai.db_path, settings.local_timeout)
.await
.context("failed to open AI session database")?;
let initial_state = Session::new(ctx.git_root.is_some());
let cwd = std::env::current_dir()
.ok()
.map(|p| p.to_string_lossy().into_owned());
let git_root_str = ctx
.git_root
.as_ref()
.map(|p| p.to_string_lossy().into_owned());
let session_window_mins = settings.ai.session_continue_minutes.max(0); // treat negative values as 0 to avoid confusion
let max_age_secs: i64 = session_window_mins * 60;
let resumable = service
.find_resumable(cwd.as_deref(), git_root_str.as_deref(), max_age_secs)
.await?;
let (session_mgr, initial_state) = if let Some(stored) = resumable {
debug!(session_id = %stored.id, "resuming AI session");
let (mgr, events, server_sid, last_event_ts, invocation_id) =
SessionManager::resume(Box::new(service), &stored).await?;
// Only treat this as a meaningful resume if there are API-visible events
// (not just OutOfBandOutput or SystemContext).
let has_api_content = events.iter().any(|e| e.is_api_content());
if has_api_content {
let mut session = Session::new(ctx.git_root.is_some(), Some(invocation_id));
session.conversation.events = events;
session.conversation.session_id = server_sid;
// Inject an invocation boundary so the LLM knows prior messages
// are from an earlier interaction.
session.conversation.events.push(
crate::tui::state::ConversationEvent::SystemContext {
content: "[Note: The user has started a new invocation of Atuin AI. Prior messages from this session are from an earlier invocation.]".to_string(),
},
);
session.view_start_index = session.conversation.events.len();
session.is_resumed = true;
session.last_event_time =
last_event_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0));
(mgr, session)
} else {
// No meaningful content — treat as a fresh session
debug!("resumable session has no API-visible content, starting fresh");
(
mgr,
Session::new(ctx.git_root.is_some(), Some(invocation_id)),
)
}
} else {
debug!("creating new AI session");
let mgr =
SessionManager::create_new(Box::new(service), cwd.as_deref(), git_root_str.as_deref());
(mgr, Session::new(ctx.git_root.is_some(), None))
};
let (tx, rx) = mpsc::channel::<AiTuiEvent>();
println!();
@@ -177,8 +240,9 @@ async fn run_inline_tui(ctx: AppContext, initial_prompt: Option<String>) -> Resu
tokio::task::spawn_blocking(move || {
let tx = tx.clone();
let client_ctx = client_ctx;
let mut session_mgr = session_mgr;
while let Ok(event) = rx.recv() {
dispatch::dispatch(&h, event, &tx, &ctx, &client_ctx);
dispatch::dispatch(&h, event, &tx, &ctx, &client_ctx, &mut session_mgr);
}
});
+578
View File
@@ -0,0 +1,578 @@
//! Context window management for API requests.
//!
//! Full conversation events are always persisted to disk. This module handles
//! truncation at send time so the API payload stays within a character budget.
//!
//! Strategy: **frozen prefix + live tail**. The first N turns form a stable
//! prefix that stays identical across requests (maximizing prompt cache hits).
//! The most recent turns form the live tail. When the total exceeds the budget,
//! turns between prefix and tail are dropped with a truncation marker. The
//! prefix never shifts, avoiding cache invalidation.
use std::ops::Range;
use crate::tui::{ConversationEvent, events_to_messages};
/// Default character budget for the context window.
/// Roughly ~50K tokens at ~4 chars/token — generous enough that truncation
/// only kicks in for genuinely long sessions.
const DEFAULT_BUDGET_CHARS: usize = 200_000;
/// Number of initial turns to freeze as the stable prefix.
const FROZEN_PREFIX_TURNS: usize = 1;
/// Builds API messages from conversation events while respecting a character
/// budget using frozen prefix + live tail truncation.
pub(crate) struct ContextWindowBuilder {
budget: usize,
}
impl ContextWindowBuilder {
pub fn new(budget: usize) -> Self {
Self { budget }
}
pub fn with_default_budget() -> Self {
Self::new(DEFAULT_BUDGET_CHARS)
}
/// Build API messages from conversation events, applying the context
/// window budget. Returns the messages to send in the API request.
pub fn build(&self, events: &[ConversationEvent]) -> Vec<serde_json::Value> {
if events.is_empty() {
return Vec::new();
}
let turns = group_into_turns(events);
// Convert each turn's events to API messages independently.
// This is safe because the combining logic (Text + ToolCall merging)
// only operates within a single assistant response, which never
// spans turn boundaries.
let turn_messages: Vec<Vec<serde_json::Value>> = turns
.iter()
.map(|range| events_to_messages(&events[range.clone()]))
.collect();
let turn_chars: Vec<usize> = turn_messages.iter().map(|m| estimate_chars(m)).collect();
let total_chars: usize = turn_chars.iter().sum();
if total_chars <= self.budget {
return turn_messages.into_iter().flatten().collect();
}
// --- Over budget: apply frozen prefix + live tail ---
let prefix_count = FROZEN_PREFIX_TURNS.min(turns.len());
let prefix_chars: usize = turn_chars[..prefix_count].iter().sum();
let marker = truncation_marker();
let marker_chars = estimate_chars(std::slice::from_ref(&marker));
let mut remaining = self.budget.saturating_sub(prefix_chars + marker_chars);
// Work backwards from the end, accumulating tail turns that fit.
let mut tail_start = turns.len();
for i in (prefix_count..turns.len()).rev() {
if turn_chars[i] <= remaining {
remaining -= turn_chars[i];
tail_start = i;
} else {
break;
}
}
// Always include at least the most recent turn, even if it alone
// exceeds the budget — sending something is better than nothing.
if tail_start >= turns.len() && turns.len() > prefix_count {
tail_start = turns.len() - 1;
}
let mut result = Vec::new();
// Frozen prefix
for msgs in &turn_messages[..prefix_count] {
result.extend(msgs.iter().cloned());
}
// Truncation marker (only if turns were actually dropped)
if tail_start > prefix_count {
result.push(marker);
}
// Live tail
for msgs in &turn_messages[tail_start..] {
result.extend(msgs.iter().cloned());
}
result
}
}
/// Marker message inserted where turns were dropped. Uses user role since
/// the preceding prefix typically ends with an assistant message.
fn truncation_marker() -> serde_json::Value {
serde_json::json!({
"role": "user",
"content": "[Earlier conversation context was omitted to fit within the context window. The conversation continues below.]"
})
}
/// Group conversation events into turns. A new turn starts at each
/// `UserMessage` or `SystemContext` event. Everything between boundaries
/// belongs to the preceding turn (assistant text, tool calls, tool results,
/// out-of-band output).
fn group_into_turns(events: &[ConversationEvent]) -> Vec<Range<usize>> {
let mut turns = Vec::new();
let mut start = 0;
for (i, event) in events.iter().enumerate() {
if i > start
&& matches!(
event,
ConversationEvent::UserMessage { .. } | ConversationEvent::SystemContext { .. }
)
{
turns.push(start..i);
start = i;
}
}
if start < events.len() {
turns.push(start..events.len());
}
turns
}
/// Rough character-count estimate for a set of messages. Uses the JSON
/// serialization length as a proxy — not exact tokens, but proportional
/// and cheap to compute.
fn estimate_chars(messages: &[serde_json::Value]) -> usize {
messages.iter().map(|m| m.to_string().len()).sum()
}
#[cfg(test)]
mod tests {
use super::*;
fn user(content: &str) -> ConversationEvent {
ConversationEvent::UserMessage {
content: content.to_string(),
}
}
fn text(content: &str) -> ConversationEvent {
ConversationEvent::Text {
content: content.to_string(),
}
}
fn tool_call(id: &str, name: &str) -> ConversationEvent {
ConversationEvent::ToolCall {
id: id.to_string(),
name: name.to_string(),
input: serde_json::json!({"command": "ls"}),
}
}
fn tool_result(tool_use_id: &str, content: &str) -> ConversationEvent {
ConversationEvent::ToolResult {
tool_use_id: tool_use_id.to_string(),
content: content.to_string(),
is_error: false,
remote: false,
content_length: None,
}
}
fn system_context(content: &str) -> ConversationEvent {
ConversationEvent::SystemContext {
content: content.to_string(),
}
}
fn oob(content: &str) -> ConversationEvent {
ConversationEvent::OutOfBandOutput {
name: "test".to_string(),
command: None,
content: content.to_string(),
}
}
// --- group_into_turns ---
#[test]
fn empty_events_produce_no_turns() {
assert!(group_into_turns(&[]).is_empty());
}
#[test]
fn single_user_message_is_one_turn() {
let events = vec![user("hello")];
let turns = group_into_turns(&events);
assert_eq!(turns, vec![0..1]);
}
#[test]
fn user_assistant_is_one_turn() {
let events = vec![user("hello"), text("hi there")];
let turns = group_into_turns(&events);
assert_eq!(turns, vec![0..2]);
}
#[test]
fn two_turns_split_at_user_message() {
let events = vec![
user("first"),
text("response 1"),
user("second"),
text("response 2"),
];
let turns = group_into_turns(&events);
assert_eq!(turns, vec![0..2, 2..4]);
}
#[test]
fn tool_calls_and_results_stay_in_same_turn() {
let events = vec![
user("list files"),
text("Let me check"),
tool_call("tc1", "suggest_command"),
tool_result("tc1", "file1\nfile2"),
text("Here are your files"),
];
let turns = group_into_turns(&events);
assert_eq!(turns, vec![0..5]);
}
#[test]
fn system_context_starts_new_turn() {
let events = vec![
user("hello"),
text("hi"),
system_context("invocation boundary"),
user("next question"),
text("answer"),
];
let turns = group_into_turns(&events);
assert_eq!(turns, vec![0..2, 2..3, 3..5]);
}
#[test]
fn oob_events_stay_in_current_turn() {
let events = vec![user("hello"), oob("some output"), text("response")];
let turns = group_into_turns(&events);
assert_eq!(turns, vec![0..3]);
}
#[test]
fn leading_text_without_user_message() {
// Edge case: events start with assistant text (shouldn't happen
// normally but handle gracefully)
let events = vec![text("orphaned"), user("hello"), text("hi")];
let turns = group_into_turns(&events);
assert_eq!(turns, vec![0..1, 1..3]);
}
// --- ContextWindowBuilder ---
#[test]
fn empty_events_produce_empty_messages() {
let builder = ContextWindowBuilder::with_default_budget();
assert!(builder.build(&[]).is_empty());
}
#[test]
fn under_budget_returns_all_messages() {
let events = vec![user("hello"), text("hi"), user("how are you"), text("good")];
let builder = ContextWindowBuilder::with_default_budget();
let messages = builder.build(&events);
// Should produce 4 messages (2 user + 2 assistant)
assert_eq!(messages.len(), 4);
assert_eq!(messages[0]["role"], "user");
assert_eq!(messages[0]["content"], "hello");
assert_eq!(messages[1]["role"], "assistant");
assert_eq!(messages[1]["content"], "hi");
assert_eq!(messages[2]["role"], "user");
assert_eq!(messages[2]["content"], "how are you");
assert_eq!(messages[3]["role"], "assistant");
assert_eq!(messages[3]["content"], "good");
}
#[test]
fn over_budget_truncates_middle_turns() {
// Create events where each turn has known content. Use a tiny
// budget so truncation is triggered with just a few turns.
let events = vec![
user("turn-1-user"),
text("turn-1-assistant"),
user("turn-2-user"),
text("turn-2-assistant"),
user("turn-3-user"),
text("turn-3-assistant"),
user("turn-4-user"),
text("turn-4-assistant-final"),
];
// Calculate sizes to set budget that keeps turn 1 (prefix) + turn 4 (tail)
// but drops turns 2 and 3.
let all_messages = events_to_messages(&events);
let total_chars: usize = all_messages.iter().map(|m| m.to_string().len()).sum();
// Set budget to roughly half — enough for prefix + last turn + marker
let turn1_msgs = events_to_messages(&events[0..2]);
let turn4_msgs = events_to_messages(&events[6..8]);
let marker_chars = estimate_chars(std::slice::from_ref(&truncation_marker()));
let needed = estimate_chars(&turn1_msgs) + estimate_chars(&turn4_msgs) + marker_chars;
// Budget allows prefix + marker + last turn but not the middle turns
assert!(
needed < total_chars,
"test setup: needed ({needed}) should be less than total ({total_chars})"
);
let builder = ContextWindowBuilder::new(needed + 10); // small margin
let messages = builder.build(&events);
// Should have: turn 1 (2 msgs) + marker (1 msg) + turn 4 (2 msgs) = 5
assert_eq!(messages.len(), 5, "expected prefix + marker + tail");
assert_eq!(messages[0]["content"], "turn-1-user");
assert_eq!(messages[1]["content"], "turn-1-assistant");
assert!(
messages[2]["content"].as_str().unwrap().contains("omitted"),
"middle message should be truncation marker"
);
assert_eq!(messages[3]["content"], "turn-4-user");
assert_eq!(messages[4]["content"], "turn-4-assistant-final");
}
#[test]
fn very_tight_budget_keeps_prefix_and_last_turn() {
let events = vec![
user("first"),
text("response-1"),
user("second"),
text("response-2"),
user("third"),
text("response-3"),
];
// Budget of 1 — forces the "always include last turn" fallback
let builder = ContextWindowBuilder::new(1);
let messages = builder.build(&events);
// Should have prefix (turn 1) + marker + last turn (turn 3)
assert!(
messages.len() >= 3,
"should have at least prefix + marker + tail"
);
// First message should be from turn 1
assert_eq!(messages[0]["content"], "first");
// Last messages should be from the final turn
let last = messages.last().unwrap();
assert_eq!(last["content"], "response-3");
}
#[test]
fn single_turn_always_returned() {
let events = vec![user("hello"), text("hi there")];
// Even with a tiny budget, the single turn must be returned
let builder = ContextWindowBuilder::new(1);
let messages = builder.build(&events);
assert_eq!(messages.len(), 2);
}
#[test]
fn tool_calls_preserved_through_truncation() {
let events = vec![
// Turn 1: simple exchange
user("turn 1"),
text("response 1"),
// Turn 2: with tool calls (will be dropped)
user("turn 2"),
text("checking"),
tool_call("tc1", "suggest_command"),
tool_result("tc1", "output"),
text("done"),
// Turn 3: final turn (kept in tail)
user("turn 3"),
text("final response"),
];
// Budget that fits turn 1 + turn 3 + marker but not turn 2
let turn1 = events_to_messages(&events[0..2]);
let turn3 = events_to_messages(&events[7..9]);
let marker_cost = estimate_chars(std::slice::from_ref(&truncation_marker()));
let budget = estimate_chars(&turn1) + estimate_chars(&turn3) + marker_cost + 10;
let builder = ContextWindowBuilder::new(budget);
let messages = builder.build(&events);
// Verify turn 2 (the tool call turn) was dropped
let has_tool_use = messages.iter().any(|m| {
m["content"]
.as_array()
.is_some_and(|arr| arr.iter().any(|b| b["type"] == "tool_use"))
});
assert!(!has_tool_use, "tool call turn should have been truncated");
// Verify first and last turns present
assert_eq!(messages[0]["content"], "turn 1");
assert_eq!(messages.last().unwrap()["content"], "final response");
}
#[test]
fn tail_accumulates_multiple_turns_when_budget_allows() {
// Use long content so turn sizes dwarf the truncation marker.
let padding = "x".repeat(500);
let events = vec![
user(&format!("turn-1-user-{padding}")),
text(&format!("turn-1-response-{padding}")),
user(&format!("turn-2-user-{padding}")),
text(&format!("turn-2-response-{padding}")),
user(&format!("turn-3-user-{padding}")),
text(&format!("turn-3-response-{padding}")),
user(&format!("turn-4-user-{padding}")),
text(&format!("turn-4-response-{padding}")),
];
// Budget that fits everything except turn 2
let all = events_to_messages(&events);
let total = estimate_chars(&all);
let turn2 = events_to_messages(&events[2..4]);
let turn2_chars = estimate_chars(&turn2);
let marker_cost = estimate_chars(std::slice::from_ref(&truncation_marker()));
let budget = total - turn2_chars + marker_cost + 5;
assert!(
budget < total,
"budget must be less than total for truncation to trigger"
);
let builder = ContextWindowBuilder::new(budget);
let messages = builder.build(&events);
// Should have: prefix (t1: 2 msgs) + marker (1 msg) + t3 (2 msgs) + t4 (2 msgs) = 7
// (turn 2 dropped)
assert_eq!(messages.len(), 7);
assert!(
messages[0]["content"]
.as_str()
.unwrap()
.starts_with("turn-1-user-")
);
assert!(
messages[1]["content"]
.as_str()
.unwrap()
.starts_with("turn-1-response-")
);
assert!(messages[2]["content"].as_str().unwrap().contains("omitted"));
assert!(
messages[3]["content"]
.as_str()
.unwrap()
.starts_with("turn-3-user-")
);
assert!(
messages[4]["content"]
.as_str()
.unwrap()
.starts_with("turn-3-response-")
);
assert!(
messages[5]["content"]
.as_str()
.unwrap()
.starts_with("turn-4-user-")
);
assert!(
messages[6]["content"]
.as_str()
.unwrap()
.starts_with("turn-4-response-")
);
}
#[test]
fn no_marker_when_no_turns_dropped() {
// Two turns, both fit in budget
let events = vec![user("a"), text("b"), user("c"), text("d")];
let builder = ContextWindowBuilder::with_default_budget();
let messages = builder.build(&events);
// No truncation marker
assert_eq!(messages.len(), 4);
assert!(
!messages
.iter()
.any(|m| m["content"].as_str().is_some_and(|s| s.contains("omitted")))
);
}
#[test]
fn tool_use_and_tool_result_never_split() {
// Invariant: a tool_use and its matching tool_result must always
// end up in the same turn, so truncation can't orphan one from
// the other. This test verifies that ToolResult does NOT start
// a new turn boundary.
let padding = "x".repeat(500);
let events = vec![
// Turn 1 (prefix)
user(&format!("turn-1-{padding}")),
text(&format!("resp-1-{padding}")),
// Turn 2: contains a tool_use → tool_result pair (will be dropped)
user(&format!("turn-2-{padding}")),
text("checking"),
tool_call("tc1", "suggest_command"),
tool_result("tc1", &format!("output-{padding}")),
text(&format!("done-{padding}")),
// Turn 3 (tail)
user(&format!("turn-3-{padding}")),
text(&format!("resp-3-{padding}")),
];
// Budget that fits turn 1 + turn 3 + marker, but not turn 2
let turn1 = events_to_messages(&events[0..2]);
let turn3 = events_to_messages(&events[7..9]);
let marker_cost = estimate_chars(std::slice::from_ref(&truncation_marker()));
let budget = estimate_chars(&turn1) + estimate_chars(&turn3) + marker_cost + 10;
let builder = ContextWindowBuilder::new(budget);
let messages = builder.build(&events);
// Verify: every tool_use has a matching tool_result, and vice versa
let tool_use_ids: Vec<&str> = messages
.iter()
.filter_map(|m| m["content"].as_array())
.flatten()
.filter(|b| b["type"] == "tool_use")
.filter_map(|b| b["id"].as_str())
.collect();
let tool_result_ids: Vec<&str> = messages
.iter()
.filter_map(|m| m["content"].as_array())
.flatten()
.filter(|b| b["type"] == "tool_result")
.filter_map(|b| b["tool_use_id"].as_str())
.collect();
assert_eq!(
tool_use_ids, tool_result_ids,
"every tool_use must have a matching tool_result (and vice versa)"
);
// Turn 2 was dropped entirely, so no tool IDs should be present
assert!(
!tool_use_ids.contains(&"tc1"),
"dropped turn's tool_use should not appear"
);
}
}
+376
View File
@@ -0,0 +1,376 @@
//! Manual serialization for ConversationEvent to/from storage format.
//!
//! The storage format is decoupled from the Rust enum so the two can evolve
//! independently. Each event is stored as an `(event_type, event_data)` pair
//! where `event_data` is a JSON string.
use eyre::{Result, eyre};
use serde_json::Value;
use crate::tui::ConversationEvent;
/// Serialize a ConversationEvent into an (event_type, event_data_json) pair
/// suitable for database storage.
pub(crate) fn serialize_event(event: &ConversationEvent) -> (String, String) {
match event {
ConversationEvent::UserMessage { content } => (
"user_message".to_string(),
serde_json::json!({ "content": content }).to_string(),
),
ConversationEvent::Text { content } => (
"text".to_string(),
serde_json::json!({ "content": content }).to_string(),
),
ConversationEvent::ToolCall { id, name, input } => (
"tool_call".to_string(),
serde_json::json!({
"id": id,
"name": name,
"input": input,
})
.to_string(),
),
ConversationEvent::ToolResult {
tool_use_id,
content,
is_error,
remote,
content_length,
} => (
"tool_result".to_string(),
serde_json::json!({
"tool_use_id": tool_use_id,
"content": content,
"is_error": is_error,
"remote": remote,
"content_length": content_length,
})
.to_string(),
),
ConversationEvent::OutOfBandOutput {
name,
command,
content,
} => (
"out_of_band_output".to_string(),
serde_json::json!({
"name": name,
"command": command,
"content": content,
})
.to_string(),
),
ConversationEvent::SystemContext { content } => (
"system_context".to_string(),
serde_json::json!({ "content": content }).to_string(),
),
}
}
/// Deserialize an (event_type, event_data_json) pair from storage back into a
/// ConversationEvent.
pub(crate) fn deserialize_event(event_type: &str, event_data: &str) -> Result<ConversationEvent> {
let data: Value = serde_json::from_str(event_data)
.map_err(|e| eyre!("failed to parse event_data JSON: {e}"))?;
match event_type {
"user_message" => Ok(ConversationEvent::UserMessage {
content: json_string(&data, "content")?,
}),
"text" => Ok(ConversationEvent::Text {
content: json_string(&data, "content")?,
}),
"tool_call" => Ok(ConversationEvent::ToolCall {
id: json_string(&data, "id")?,
name: json_string(&data, "name")?,
input: data
.get("input")
.cloned()
.ok_or_else(|| eyre!("tool_call missing 'input' field"))?,
}),
"tool_result" => Ok(ConversationEvent::ToolResult {
tool_use_id: json_string(&data, "tool_use_id")?,
content: json_string(&data, "content")?,
is_error: data
.get("is_error")
.and_then(Value::as_bool)
.ok_or_else(|| eyre!("tool_result missing 'is_error' field"))?,
remote: data.get("remote").and_then(Value::as_bool).unwrap_or(false),
content_length: data
.get("content_length")
.and_then(Value::as_u64)
.map(|v| v as usize),
}),
"out_of_band_output" => Ok(ConversationEvent::OutOfBandOutput {
name: json_string(&data, "name")?,
command: data
.get("command")
.and_then(|v| if v.is_null() { None } else { v.as_str() })
.map(String::from),
content: json_string(&data, "content")?,
}),
"system_context" => Ok(ConversationEvent::SystemContext {
content: json_string(&data, "content")?,
}),
other => Err(eyre!("unknown event type: {other}")),
}
}
fn json_string(data: &Value, field: &str) -> Result<String> {
data.get(field)
.and_then(Value::as_str)
.map(String::from)
.ok_or_else(|| eyre!("missing or non-string field '{field}'"))
}
#[cfg(test)]
mod tests {
use super::*;
fn round_trip(event: &ConversationEvent) -> ConversationEvent {
let (event_type, event_data) = serialize_event(event);
deserialize_event(&event_type, &event_data).unwrap()
}
#[test]
fn test_user_message() {
let event = ConversationEvent::UserMessage {
content: "hello world".to_string(),
};
let result = round_trip(&event);
assert!(
matches!(result, ConversationEvent::UserMessage { content } if content == "hello world")
);
}
#[test]
fn test_text() {
let event = ConversationEvent::Text {
content: "response text".to_string(),
};
let result = round_trip(&event);
assert!(
matches!(result, ConversationEvent::Text { content } if content == "response text")
);
}
#[test]
fn test_tool_call() {
let input = serde_json::json!({"command": "ls -la", "danger": "low"});
let event = ConversationEvent::ToolCall {
id: "tc_123".to_string(),
name: "suggest_command".to_string(),
input: input.clone(),
};
let result = round_trip(&event);
match result {
ConversationEvent::ToolCall {
id,
name,
input: result_input,
} => {
assert_eq!(id, "tc_123");
assert_eq!(name, "suggest_command");
assert_eq!(result_input, input);
}
_ => panic!("expected ToolCall"),
}
}
#[test]
fn test_tool_result() {
let event = ConversationEvent::ToolResult {
tool_use_id: "tc_123".to_string(),
content: "file contents here".to_string(),
is_error: false,
remote: false,
content_length: None,
};
let result = round_trip(&event);
match result {
ConversationEvent::ToolResult {
tool_use_id,
content,
is_error,
remote,
content_length,
} => {
assert_eq!(tool_use_id, "tc_123");
assert_eq!(content, "file contents here");
assert!(!is_error);
assert!(!remote);
assert!(content_length.is_none());
}
_ => panic!("expected ToolResult"),
}
}
#[test]
fn test_tool_result_error() {
let event = ConversationEvent::ToolResult {
tool_use_id: "tc_456".to_string(),
content: "permission denied".to_string(),
is_error: true,
remote: false,
content_length: None,
};
let result = round_trip(&event);
match result {
ConversationEvent::ToolResult { is_error, .. } => assert!(is_error),
_ => panic!("expected ToolResult"),
}
}
#[test]
fn test_tool_result_remote() {
let event = ConversationEvent::ToolResult {
tool_use_id: "tc_789".to_string(),
content: "ref:abc123".to_string(),
is_error: false,
remote: true,
content_length: Some(4096),
};
let result = round_trip(&event);
match result {
ConversationEvent::ToolResult {
remote,
content_length,
..
} => {
assert!(remote);
assert_eq!(content_length, Some(4096));
}
_ => panic!("expected ToolResult"),
}
}
#[test]
fn test_tool_result_backwards_compat() {
// Old stored data without remote/content_length fields should deserialize
// with defaults (remote=false, content_length=None)
let event = deserialize_event(
"tool_result",
r#"{"tool_use_id":"tc_old","content":"old result","is_error":false}"#,
)
.unwrap();
match event {
ConversationEvent::ToolResult {
remote,
content_length,
..
} => {
assert!(!remote);
assert!(content_length.is_none());
}
_ => panic!("expected ToolResult"),
}
}
#[test]
fn test_out_of_band_with_command() {
let event = ConversationEvent::OutOfBandOutput {
name: "System".to_string(),
command: Some("/help".to_string()),
content: "help text".to_string(),
};
let result = round_trip(&event);
match result {
ConversationEvent::OutOfBandOutput {
name,
command,
content,
} => {
assert_eq!(name, "System");
assert_eq!(command.as_deref(), Some("/help"));
assert_eq!(content, "help text");
}
_ => panic!("expected OutOfBandOutput"),
}
}
#[test]
fn test_out_of_band_without_command() {
let event = ConversationEvent::OutOfBandOutput {
name: "System".to_string(),
command: None,
content: "some output".to_string(),
};
let result = round_trip(&event);
match result {
ConversationEvent::OutOfBandOutput { command, .. } => {
assert!(command.is_none());
}
_ => panic!("expected OutOfBandOutput"),
}
}
#[test]
fn test_unknown_event_type() {
let result = deserialize_event("banana", "{}");
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("unknown event type")
);
}
#[test]
fn test_invalid_json() {
let result = deserialize_event("text", "not json");
assert!(result.is_err());
}
#[test]
fn test_missing_field() {
let result = deserialize_event("text", "{}");
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("content"));
}
#[test]
fn test_text_with_special_characters() {
let event = ConversationEvent::Text {
content: "line1\nline2\ttab \"quotes\" \\backslash 🎉".to_string(),
};
let result = round_trip(&event);
assert!(
matches!(result, ConversationEvent::Text { content } if content == "line1\nline2\ttab \"quotes\" \\backslash 🎉")
);
}
#[test]
fn test_tool_call_with_nested_input() {
let input = serde_json::json!({
"command": "echo 'hello'",
"nested": { "a": [1, 2, 3], "b": null }
});
let event = ConversationEvent::ToolCall {
id: "tc_1".to_string(),
name: "execute_shell_command".to_string(),
input: input.clone(),
};
let result = round_trip(&event);
match result {
ConversationEvent::ToolCall {
input: result_input,
..
} => {
assert_eq!(result_input, input);
}
_ => panic!("expected ToolCall"),
}
}
#[test]
fn test_system_context() {
let event = ConversationEvent::SystemContext {
content: "[system: new invocation started]".to_string(),
};
let result = round_trip(&event);
assert!(
matches!(result, ConversationEvent::SystemContext { content } if content == "[system: new invocation started]")
);
}
}
+4
View File
@@ -1,6 +1,10 @@
pub mod commands;
pub(crate) mod context;
pub(crate) mod context_window;
pub(crate) mod event_serde;
pub(crate) mod permissions;
pub(crate) mod session;
pub(crate) mod store;
pub(crate) mod stream;
pub(crate) mod tools;
pub(crate) mod tui;
+482
View File
@@ -0,0 +1,482 @@
//! Session service abstraction and manager.
//!
//! The TUI interacts with sessions through `SessionManager`, which wraps a
//! `SessionService` trait. Today the only implementation is `LocalSessionService`
//! (direct SQLite). When the daemon owns session state, a gRPC-backed
//! implementation can be swapped in without changing the TUI code.
use async_trait::async_trait;
use eyre::Result;
use crate::event_serde;
use crate::store::{AiSessionStore, StoredEvent, StoredSession};
use crate::tui::ConversationEvent;
// ---------------------------------------------------------------------------
// Trait
// ---------------------------------------------------------------------------
#[async_trait]
pub(crate) trait SessionService: Send + Sync {
async fn create_session(
&self,
id: &str,
directory: Option<&str>,
git_root: Option<&str>,
) -> Result<StoredSession>;
async fn find_resumable(
&self,
directory: Option<&str>,
git_root: Option<&str>,
max_age_secs: i64,
) -> Result<Option<StoredSession>>;
async fn load_events(&self, session_id: &str) -> Result<Vec<StoredEvent>>;
async fn append_event(
&self,
session_id: &str,
event_id: &str,
parent_id: Option<&str>,
invocation_id: &str,
event_type: &str,
event_data: &str,
) -> Result<()>;
async fn update_server_session_id(
&self,
session_id: &str,
server_session_id: &str,
) -> Result<()>;
async fn archive(&self, session_id: &str) -> Result<()>;
}
// ---------------------------------------------------------------------------
// Local implementation (direct SQLite)
// ---------------------------------------------------------------------------
pub(crate) struct LocalSessionService {
store: AiSessionStore,
}
impl LocalSessionService {
pub async fn open(path: impl AsRef<std::path::Path>, timeout: f64) -> Result<Self> {
let store = AiSessionStore::new(path, timeout).await?;
Ok(Self { store })
}
}
#[async_trait]
impl SessionService for LocalSessionService {
async fn create_session(
&self,
id: &str,
directory: Option<&str>,
git_root: Option<&str>,
) -> Result<StoredSession> {
self.store.create_session(id, directory, git_root).await
}
async fn find_resumable(
&self,
directory: Option<&str>,
git_root: Option<&str>,
max_age_secs: i64,
) -> Result<Option<StoredSession>> {
self.store
.find_resumable_session(directory, git_root, max_age_secs)
.await
}
async fn load_events(&self, session_id: &str) -> Result<Vec<StoredEvent>> {
self.store.load_events(session_id).await
}
async fn append_event(
&self,
session_id: &str,
event_id: &str,
parent_id: Option<&str>,
invocation_id: &str,
event_type: &str,
event_data: &str,
) -> Result<()> {
self.store
.append_event(
session_id,
event_id,
parent_id,
invocation_id,
event_type,
event_data,
)
.await
}
async fn update_server_session_id(
&self,
session_id: &str,
server_session_id: &str,
) -> Result<()> {
self.store
.update_server_session_id(session_id, server_session_id)
.await
}
async fn archive(&self, session_id: &str) -> Result<()> {
self.store.archive_session(session_id).await
}
}
// ---------------------------------------------------------------------------
// SessionManager
// ---------------------------------------------------------------------------
/// High-level session manager used by the TUI dispatch loop.
///
/// Owns the current session identity, tracks what has been persisted, and
/// handles serialization between `ConversationEvent` and the storage format.
pub(crate) struct SessionManager {
service: Box<dyn SessionService>,
session_id: String,
invocation_id: String,
/// Number of events already persisted. `persist_events` only writes the
/// delta from this index onward.
persisted_count: usize,
/// ID of the last persisted event, used as `parent_id` for the next one.
head_id: Option<String>,
/// Stored for creating a new session on `/new`.
directory: Option<String>,
git_root: Option<String>,
/// Whether the session row has been created in the database. New sessions
/// are deferred until the first event is persisted, so empty sessions
/// don't linger and get spuriously resumed.
persisted_to_db: bool,
}
impl SessionManager {
/// Create a new session manager. The database row is deferred until the
/// first event is persisted.
pub fn create_new(
service: Box<dyn SessionService>,
directory: Option<&str>,
git_root: Option<&str>,
) -> Self {
let session_id = atuin_common::utils::uuid_v7().to_string();
let invocation_id = atuin_common::utils::uuid_v7().to_string();
Self {
service,
session_id,
invocation_id,
persisted_count: 0,
head_id: None,
directory: directory.map(String::from),
git_root: git_root.map(String::from),
persisted_to_db: false,
}
}
/// Load an existing session and return a manager for it, along with the
/// deserialized conversation events, the server session ID, and the
/// timestamp of the last stored event.
pub async fn resume(
service: Box<dyn SessionService>,
stored: &StoredSession,
) -> Result<(
Self,
Vec<ConversationEvent>,
Option<String>,
Option<i64>,
String,
)> {
let invocation_id = atuin_common::utils::uuid_v7().to_string();
let stored_events = service.load_events(&stored.id).await?;
let mut events = Vec::with_capacity(stored_events.len());
let mut last_event_id = None;
let mut last_event_ts = None;
for se in &stored_events {
events.push(event_serde::deserialize_event(
&se.event_type,
&se.event_data,
)?);
last_event_id = Some(se.id.clone());
last_event_ts = Some(se.created_at);
}
let manager = Self {
service,
session_id: stored.id.clone(),
invocation_id: invocation_id.clone(),
persisted_count: events.len(),
head_id: last_event_id,
directory: stored.directory.clone(),
git_root: stored.git_root.clone(),
persisted_to_db: true,
};
Ok((
manager,
events,
stored.server_session_id.clone(),
last_event_ts,
invocation_id,
))
}
/// Ensure the session row exists in the database.
async fn ensure_persisted(&mut self) -> Result<()> {
if !self.persisted_to_db {
self.service
.create_session(
&self.session_id,
self.directory.as_deref(),
self.git_root.as_deref(),
)
.await?;
self.persisted_to_db = true;
}
Ok(())
}
/// Persist any new events since the last persist call.
pub async fn persist_events(&mut self, events: &[ConversationEvent]) -> Result<()> {
if self.persisted_count >= events.len() {
return Ok(());
}
self.ensure_persisted().await?;
for event in &events[self.persisted_count..] {
let event_id = atuin_common::utils::uuid_v7().to_string();
let (event_type, event_data) = event_serde::serialize_event(event);
self.service
.append_event(
&self.session_id,
&event_id,
self.head_id.as_deref(),
&self.invocation_id,
&event_type,
&event_data,
)
.await?;
self.head_id = Some(event_id);
self.persisted_count += 1;
}
Ok(())
}
/// Persist the server session ID if it has changed.
pub async fn persist_server_session_id(&mut self, server_session_id: &str) -> Result<()> {
self.ensure_persisted().await?;
self.service
.update_server_session_id(&self.session_id, server_session_id)
.await
}
/// Archive the current session (for `/new` command).
#[allow(dead_code)] // used in tests; will be used by dispatch for `/new`
pub async fn archive(&self) -> Result<()> {
if self.persisted_to_db {
self.service.archive(&self.session_id).await?;
}
Ok(())
}
/// Archive the current session and reset to a fresh one.
/// The new session row is deferred until the first event is persisted.
pub async fn archive_and_reset(&mut self) -> Result<()> {
if self.persisted_to_db {
self.service.archive(&self.session_id).await?;
}
self.session_id = atuin_common::utils::uuid_v7().to_string();
self.invocation_id = atuin_common::utils::uuid_v7().to_string();
self.persisted_count = 0;
self.head_id = None;
self.persisted_to_db = false;
Ok(())
}
#[allow(dead_code)] // used in tests; part of public API for dispatch/daemon
pub fn session_id(&self) -> &str {
&self.session_id
}
#[allow(dead_code)] // used in tests; part of public API for dispatch/daemon
pub fn invocation_id(&self) -> &str {
&self.invocation_id
}
}
#[cfg(test)]
mod tests {
use super::*;
async fn test_service() -> Box<dyn SessionService> {
let svc = LocalSessionService::open("sqlite::memory:", 2.0)
.await
.unwrap();
Box::new(svc)
}
#[tokio::test]
async fn test_create_new_and_persist() {
let service = test_service().await;
let mut mgr = SessionManager::create_new(service, Some("/tmp"), None);
let events = vec![
ConversationEvent::UserMessage {
content: "hello".to_string(),
},
ConversationEvent::Text {
content: "hi there".to_string(),
},
];
mgr.persist_events(&events).await.unwrap();
// Persist again with no new events — should be a no-op
mgr.persist_events(&events).await.unwrap();
}
#[tokio::test]
async fn test_create_and_resume() {
// Create a session and persist some events
let svc = LocalSessionService::open("sqlite::memory:", 2.0)
.await
.unwrap();
let session_id = atuin_common::utils::uuid_v7().to_string();
svc.create_session(&session_id, Some("/project"), Some("/project"))
.await
.unwrap();
let events = vec![
ConversationEvent::UserMessage {
content: "how do I list files?".to_string(),
},
ConversationEvent::Text {
content: "Use ls".to_string(),
},
ConversationEvent::ToolCall {
id: "tc_1".to_string(),
name: "suggest_command".to_string(),
input: serde_json::json!({"command": "ls -la"}),
},
];
// Persist events manually through the service
let inv_id = "inv-1";
let mut parent: Option<String> = None;
for event in &events {
let eid = atuin_common::utils::uuid_v7().to_string();
let (etype, edata) = event_serde::serialize_event(event);
svc.append_event(&session_id, &eid, parent.as_deref(), inv_id, &etype, &edata)
.await
.unwrap();
parent = Some(eid);
}
svc.update_server_session_id(&session_id, "srv-abc")
.await
.unwrap();
// Now find and resume the session with a fresh service connection
let stored = svc
.find_resumable(Some("/project"), Some("/project"), 3600)
.await
.unwrap()
.expect("should find session");
let (mut mgr, loaded_events, server_sid, last_ts, _invocation_id) =
SessionManager::resume(Box::new(svc), &stored)
.await
.unwrap();
assert_eq!(loaded_events.len(), 3);
assert_eq!(server_sid.as_deref(), Some("srv-abc"));
assert_ne!(mgr.invocation_id(), inv_id, "new invocation ID on resume");
assert!(last_ts.is_some(), "should have a last event timestamp");
// Persisting again with the same events should be a no-op
mgr.persist_events(&loaded_events).await.unwrap();
}
#[tokio::test]
async fn test_incremental_persist() {
let service = test_service().await;
let mut mgr = SessionManager::create_new(service, Some("/tmp"), None);
let mut events = vec![ConversationEvent::UserMessage {
content: "first".to_string(),
}];
mgr.persist_events(&events).await.unwrap();
// Add more events and persist again — only the new ones should be written
events.push(ConversationEvent::Text {
content: "response".to_string(),
});
events.push(ConversationEvent::UserMessage {
content: "second".to_string(),
});
mgr.persist_events(&events).await.unwrap();
// Verify by loading through a fresh service (can't easily here since
// the service is moved, but the lack of errors confirms correctness)
}
#[tokio::test]
async fn test_archive() {
let svc = LocalSessionService::open("sqlite::memory:", 2.0)
.await
.unwrap();
let mgr = SessionManager::create_new(Box::new(svc), Some("/tmp"), None);
mgr.archive().await.unwrap();
}
#[tokio::test]
async fn test_persist_server_session_id() {
let service = test_service().await;
let mut mgr = SessionManager::create_new(service, Some("/tmp"), None);
mgr.persist_server_session_id("srv-123").await.unwrap();
}
#[tokio::test]
async fn test_parent_chain_integrity() {
// Verify that persisted events form a proper parent chain
let svc = LocalSessionService::open("sqlite::memory:", 2.0)
.await
.unwrap();
let session_id = {
let mut mgr = SessionManager::create_new(Box::new(svc), Some("/tmp"), None);
let events = vec![
ConversationEvent::UserMessage {
content: "a".to_string(),
},
ConversationEvent::Text {
content: "b".to_string(),
},
ConversationEvent::UserMessage {
content: "c".to_string(),
},
];
mgr.persist_events(&events).await.unwrap();
mgr.session_id().to_string()
};
// Re-open the store and load events to verify the chain
// (Can't do this with in-memory DB since it's gone, but the
// lack of FK constraint violations during persist confirms the
// parent_id values are valid)
let _ = session_id;
}
}
+522
View File
@@ -0,0 +1,522 @@
use std::path::Path;
use std::str::FromStr;
use std::time::Duration;
use eyre::{Result, eyre};
use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions};
use time::OffsetDateTime;
// Database row mappings — all columns are kept even if not yet read in
// non-test code, since they're part of the schema and used in tests.
#[derive(Debug)]
#[allow(dead_code)]
pub(crate) struct StoredSession {
pub id: String,
pub head_id: Option<String>,
pub server_session_id: Option<String>,
pub directory: Option<String>,
pub git_root: Option<String>,
pub created_at: i64,
pub updated_at: i64,
pub archived_at: Option<i64>,
}
#[derive(Debug)]
#[allow(dead_code)]
pub(crate) struct StoredEvent {
pub id: String,
pub session_id: String,
pub parent_id: Option<String>,
pub invocation_id: String,
pub event_type: String,
pub event_data: String,
pub created_at: i64,
}
/// Row type returned by session queries (avoids clippy::type_complexity).
type SessionRow = (
String,
Option<String>,
Option<String>,
Option<String>,
Option<String>,
i64,
i64,
Option<i64>,
);
/// Row type returned by event queries.
type EventRow = (String, String, Option<String>, String, String, String, i64);
pub(crate) struct AiSessionStore {
pool: SqlitePool,
}
impl AiSessionStore {
pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
let path = path.as_ref();
let path_str = path
.as_os_str()
.to_str()
.ok_or_else(|| eyre!("AI session database path is not valid UTF-8: {path:?}"))?;
let is_memory = path_str.contains(":memory:");
if !is_memory
&& !path.exists()
&& let Some(dir) = path.parent()
{
fs_err::create_dir_all(dir)?;
}
let opts = SqliteConnectOptions::from_str(path_str)?
.journal_mode(SqliteJournalMode::Wal)
.optimize_on_close(true, None)
.create_if_missing(true);
let pool = SqlitePoolOptions::new()
.acquire_timeout(Duration::from_secs_f64(timeout))
.connect_with(opts)
.await?;
sqlx::migrate!("./migrations").run(&pool).await?;
#[cfg(unix)]
if !is_memory {
use std::os::unix::fs::PermissionsExt;
std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?;
}
Ok(Self { pool })
}
pub async fn create_session(
&self,
id: &str,
directory: Option<&str>,
git_root: Option<&str>,
) -> Result<StoredSession> {
let now = OffsetDateTime::now_utc().unix_timestamp();
sqlx::query(
"INSERT INTO sessions (id, directory, git_root, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?4)",
)
.bind(id)
.bind(directory)
.bind(git_root)
.bind(now)
.execute(&self.pool)
.await?;
Ok(StoredSession {
id: id.to_string(),
head_id: None,
server_session_id: None,
directory: directory.map(String::from),
git_root: git_root.map(String::from),
created_at: now,
updated_at: now,
archived_at: None,
})
}
#[allow(dead_code)] // used in tests; will be used by daemon service
pub async fn get_session(&self, id: &str) -> Result<Option<StoredSession>> {
let row: Option<SessionRow> = sqlx::query_as(
"SELECT id, head_id, server_session_id, directory, git_root,
created_at, updated_at, archived_at
FROM sessions WHERE id = ?1",
)
.bind(id)
.fetch_optional(&self.pool)
.await?;
Ok(row.map(
|(
id,
head_id,
server_session_id,
directory,
git_root,
created_at,
updated_at,
archived_at,
)| {
StoredSession {
id,
head_id,
server_session_id,
directory,
git_root,
created_at,
updated_at,
archived_at,
}
},
))
}
/// Find the most recent non-archived session matching the given directory or git
/// root, updated within `max_age_secs` seconds.
pub async fn find_resumable_session(
&self,
directory: Option<&str>,
git_root: Option<&str>,
max_age_secs: i64,
) -> Result<Option<StoredSession>> {
let cutoff = OffsetDateTime::now_utc().unix_timestamp() - max_age_secs;
let row: Option<SessionRow> = sqlx::query_as(
"SELECT id, head_id, server_session_id, directory, git_root,
created_at, updated_at, archived_at
FROM sessions
WHERE archived_at IS NULL
AND updated_at > ?1
AND (directory = ?2 OR (git_root IS NOT NULL AND git_root = ?3))
ORDER BY updated_at DESC
LIMIT 1",
)
.bind(cutoff)
.bind(directory)
.bind(git_root)
.fetch_optional(&self.pool)
.await?;
Ok(row.map(
|(
id,
head_id,
server_session_id,
directory,
git_root,
created_at,
updated_at,
archived_at,
)| {
StoredSession {
id,
head_id,
server_session_id,
directory,
git_root,
created_at,
updated_at,
archived_at,
}
},
))
}
/// Append a single event and update the session's `head_id` and `updated_at`.
pub async fn append_event(
&self,
session_id: &str,
event_id: &str,
parent_id: Option<&str>,
invocation_id: &str,
event_type: &str,
event_data: &str,
) -> Result<()> {
let now = OffsetDateTime::now_utc().unix_timestamp();
let mut tx = self.pool.begin().await?;
sqlx::query(
"INSERT INTO session_events (id, session_id, parent_id, invocation_id, event_type, event_data, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
)
.bind(event_id)
.bind(session_id)
.bind(parent_id)
.bind(invocation_id)
.bind(event_type)
.bind(event_data)
.bind(now)
.execute(&mut *tx)
.await?;
sqlx::query("UPDATE sessions SET head_id = ?1, updated_at = ?2 WHERE id = ?3")
.bind(event_id)
.bind(now)
.bind(session_id)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(())
}
/// Load all events for a session, ordered chronologically.
pub async fn load_events(&self, session_id: &str) -> Result<Vec<StoredEvent>> {
let rows: Vec<EventRow> = sqlx::query_as(
"SELECT id, session_id, parent_id, invocation_id, event_type, event_data, created_at
FROM session_events
WHERE session_id = ?1
ORDER BY created_at ASC, rowid ASC",
)
.bind(session_id)
.fetch_all(&self.pool)
.await?;
Ok(rows
.into_iter()
.map(
|(id, session_id, parent_id, invocation_id, event_type, event_data, created_at)| {
StoredEvent {
id,
session_id,
parent_id,
invocation_id,
event_type,
event_data,
created_at,
}
},
)
.collect())
}
pub async fn update_server_session_id(
&self,
session_id: &str,
server_session_id: &str,
) -> Result<()> {
sqlx::query("UPDATE sessions SET server_session_id = ?1 WHERE id = ?2")
.bind(server_session_id)
.bind(session_id)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn archive_session(&self, session_id: &str) -> Result<()> {
let now = OffsetDateTime::now_utc().unix_timestamp();
sqlx::query("UPDATE sessions SET archived_at = ?1 WHERE id = ?2")
.bind(now)
.bind(session_id)
.execute(&self.pool)
.await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
async fn new_test_store() -> AiSessionStore {
AiSessionStore::new("sqlite::memory:", 2.0).await.unwrap()
}
#[tokio::test]
async fn test_create_and_get_session() {
let store = new_test_store().await;
let session = store
.create_session("s1", Some("/home/user/project"), Some("/home/user/project"))
.await
.unwrap();
assert_eq!(session.id, "s1");
assert!(session.head_id.is_none());
assert!(session.archived_at.is_none());
let loaded = store.get_session("s1").await.unwrap().unwrap();
assert_eq!(loaded.id, "s1");
assert_eq!(loaded.directory.as_deref(), Some("/home/user/project"));
}
#[tokio::test]
async fn test_get_nonexistent_session() {
let store = new_test_store().await;
assert!(store.get_session("nope").await.unwrap().is_none());
}
#[tokio::test]
async fn test_append_and_load_events() {
let store = new_test_store().await;
store
.create_session("s1", Some("/tmp"), None)
.await
.unwrap();
store
.append_event(
"s1",
"e1",
None,
"inv1",
"user_message",
r#"{"content":"hello"}"#,
)
.await
.unwrap();
store
.append_event(
"s1",
"e2",
Some("e1"),
"inv1",
"text",
r#"{"content":"hi there"}"#,
)
.await
.unwrap();
let events = store.load_events("s1").await.unwrap();
assert_eq!(events.len(), 2);
assert_eq!(events[0].id, "e1");
assert!(events[0].parent_id.is_none());
assert_eq!(events[0].invocation_id, "inv1");
assert_eq!(events[1].id, "e2");
assert_eq!(events[1].parent_id.as_deref(), Some("e1"));
let session = store.get_session("s1").await.unwrap().unwrap();
assert_eq!(session.head_id.as_deref(), Some("e2"));
}
#[tokio::test]
async fn test_find_resumable_session() {
let store = new_test_store().await;
store
.create_session("s1", Some("/home/user/project"), None)
.await
.unwrap();
let found = store
.find_resumable_session(Some("/home/user/project"), None, 3600)
.await
.unwrap();
assert!(found.is_some());
assert_eq!(found.unwrap().id, "s1");
}
#[tokio::test]
async fn test_find_resumable_by_git_root() {
let store = new_test_store().await;
store
.create_session(
"s1",
Some("/home/user/project/sub"),
Some("/home/user/project"),
)
.await
.unwrap();
let found = store
.find_resumable_session(Some("/different/dir"), Some("/home/user/project"), 3600)
.await
.unwrap();
assert!(found.is_some());
assert_eq!(found.unwrap().id, "s1");
}
#[tokio::test]
async fn test_find_resumable_skips_archived() {
let store = new_test_store().await;
store
.create_session("s1", Some("/tmp"), None)
.await
.unwrap();
store.archive_session("s1").await.unwrap();
let found = store
.find_resumable_session(Some("/tmp"), None, 3600)
.await
.unwrap();
assert!(found.is_none());
}
#[tokio::test]
async fn test_find_resumable_no_match_different_dir() {
let store = new_test_store().await;
store
.create_session("s1", Some("/home/user/project"), None)
.await
.unwrap();
let found = store
.find_resumable_session(Some("/other/dir"), None, 3600)
.await
.unwrap();
assert!(found.is_none());
}
#[tokio::test]
async fn test_archive_session() {
let store = new_test_store().await;
store
.create_session("s1", Some("/tmp"), None)
.await
.unwrap();
store.archive_session("s1").await.unwrap();
let session = store.get_session("s1").await.unwrap().unwrap();
assert!(session.archived_at.is_some());
}
#[tokio::test]
async fn test_update_server_session_id() {
let store = new_test_store().await;
store
.create_session("s1", Some("/tmp"), None)
.await
.unwrap();
store
.update_server_session_id("s1", "server-abc")
.await
.unwrap();
let session = store.get_session("s1").await.unwrap().unwrap();
assert_eq!(session.server_session_id.as_deref(), Some("server-abc"));
}
#[tokio::test]
async fn test_find_resumable_does_not_mutate() {
let store = new_test_store().await;
store
.create_session("s1", Some("/tmp"), None)
.await
.unwrap();
let before = store.get_session("s1").await.unwrap().unwrap();
store
.find_resumable_session(Some("/tmp"), None, 3600)
.await
.unwrap()
.unwrap();
let after = store.get_session("s1").await.unwrap().unwrap();
assert_eq!(before.updated_at, after.updated_at);
}
#[tokio::test]
async fn test_events_ordered_chronologically() {
let store = new_test_store().await;
store
.create_session("s1", Some("/tmp"), None)
.await
.unwrap();
store
.append_event("s1", "e1", None, "inv1", "user_message", "{}")
.await
.unwrap();
store
.append_event("s1", "e2", Some("e1"), "inv1", "text", "{}")
.await
.unwrap();
store
.append_event("s1", "e3", Some("e2"), "inv2", "user_message", "{}")
.await
.unwrap();
let events = store.load_events("s1").await.unwrap();
assert_eq!(events.len(), 3);
assert!(events[0].created_at <= events[1].created_at);
assert!(events[1].created_at <= events[2].created_at);
assert_eq!(events[2].invocation_id, "inv2");
}
}
+9 -1
View File
@@ -12,6 +12,7 @@ use eye_declare::Handle;
use eyre::{Context, Result};
use futures::StreamExt;
use reqwest::Url;
use reqwest::header::USER_AGENT;
use crate::{
context::{AppContext, ClientContext},
@@ -19,6 +20,8 @@ use crate::{
tui::{Session, events::AiTuiEvent},
};
static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"));
/// Frames that alter the stream lifecycle — terminal or state-changing.
#[derive(Debug, Clone)]
pub(crate) enum StreamControl {
@@ -57,6 +60,7 @@ pub(crate) struct ChatRequest {
pub messages: Vec<serde_json::Value>,
pub session_id: Option<String>,
pub capabilities: Vec<String>,
pub invocation_id: String,
}
impl ChatRequest {
@@ -64,8 +68,9 @@ impl ChatRequest {
messages: Vec<serde_json::Value>,
session_id: Option<String>,
capabilities: &AiCapabilities,
invocation_id: String,
) -> Self {
let mut caps = vec![];
let mut caps = vec!["client_invocations".to_string()];
if capabilities.enable_history_search.unwrap_or(true) {
caps.push("client_v1_atuin_history".to_string());
}
@@ -82,6 +87,7 @@ impl ChatRequest {
messages,
session_id,
capabilities: caps,
invocation_id,
}
}
}
@@ -112,6 +118,7 @@ fn create_chat_stream(
"messages": request.messages,
"context": context,
"capabilities": request.capabilities,
"invocation_id": request.invocation_id
});
if let Some(ref sid) = request.session_id {
@@ -123,6 +130,7 @@ fn create_chat_stream(
let response = match client
.post(endpoint.clone())
.header("Accept", "text/event-stream")
.header(USER_AGENT, APP_USER_AGENT)
.bearer_auth(&token)
.json(&request_body)
.send()
@@ -19,7 +19,7 @@ use ratatui_core::{
};
use tui_textarea::TextArea;
use crate::tui::events::AiTuiEvent;
use crate::tui::{events::AiTuiEvent, slash::SlashCommandSearchResult};
/// A bordered text input box backed by tui-textarea.
///
@@ -35,6 +35,8 @@ pub(crate) struct InputBox {
pub footer: String,
/// Whether the input is currently active (shows cursor, accepts input)
pub active: bool,
/// If the user has typed a slash command, this holds the best match for it.
pub slash_suggestion: Option<SlashCommandSearchResult>,
}
pub(crate) struct InputBoxState {
@@ -129,6 +131,18 @@ fn input_box(
textarea.insert_newline();
return EventResult::Consumed;
}
crossterm::event::KeyCode::Tab if props.slash_suggestion.is_some() => {
// If there's a slash command suggestion, Tab accepts it.
if let Some(suggestion) = &props.slash_suggestion {
textarea.clear();
textarea.insert_str(format!("/{}", suggestion.command.name));
// Manually trigger an input update event so the slash suggestion box can update immediately
if let Some(ref tx) = state.tx {
let _ = tx.send(AiTuiEvent::InputUpdated(textarea.lines().join("\n")));
}
return EventResult::Consumed;
}
}
crossterm::event::KeyCode::Enter => {
if key.modifiers.contains(KeyModifiers::SHIFT) {
textarea.insert_newline();
@@ -2,3 +2,4 @@ pub(crate) mod atuin_ai;
pub(crate) mod input_box;
pub(crate) mod markdown;
pub(crate) mod select;
pub(crate) mod session_continue;
@@ -0,0 +1,49 @@
use chrono_humanize::HumanTime;
use eye_declare::{Elements, Hooks, Span, Text, component, element, props};
use ratatui::style::{Color, Modifier, Style};
#[props]
pub(crate) struct SessionContinue {
pub continued_at: Option<chrono::DateTime<chrono::Utc>>,
}
#[derive(Default)]
pub(crate) struct SessionContinueState {
/// Frozen on mount so the label doesn't change on every render.
label: Option<String>,
}
#[component(props = SessionContinue, state = SessionContinueState)]
fn session_continue(
_props: &SessionContinue,
state: &SessionContinueState,
hooks: &mut Hooks<SessionContinue, SessionContinueState>,
) -> Elements {
hooks.use_mount(|props, state| {
state.label = Some(match props.continued_at {
Some(t) => {
let human = HumanTime::from(t - chrono::Utc::now());
format!(
" Continuing previous session (last active {human}) - type /new to start a new session"
)
}
None => {
" Continuing previous session - type /new to start a new session".to_string()
}
});
});
let resume_label = state
.label
.as_deref()
.unwrap_or(" Continuing previous session - type /new to start a new session");
element! {
Text {
Span(
text: resume_label,
style: Style::default().fg(Color::DarkGray).add_modifier(Modifier::ITALIC),
)
}
}
}
+3
View File
@@ -1,3 +1,6 @@
Welcome to Atuin AI, an AI assistant in your terminal. You can ask it to generate a shell command for you, or ask general terminal or software questions.
Commands:
{commands}
For more information, see [https://docs.atuin.sh/cli/ai/introduction/](https://docs.atuin.sh/cli/ai/introduction/)
+108 -9
View File
@@ -2,14 +2,16 @@ use std::path::PathBuf;
use std::sync::mpsc;
use crate::context::{AppContext, ClientContext};
use crate::context_window::ContextWindowBuilder;
use crate::permissions::check::PermissionResponse;
use crate::permissions::resolver::PermissionResolver;
use crate::permissions::rule::Rule;
use crate::permissions::writer::{self, RuleDisposition};
use crate::session::SessionManager;
use crate::stream::{ChatRequest, run_chat_stream};
use crate::tools::{ClientToolCall, ToolPhase};
use crate::tui::events::{AiTuiEvent, PermissionResult};
use crate::tui::state::{ExitAction, Session};
use crate::tui::state::{ConversationEvent, ExitAction, Session};
use eye_declare::Handle;
use tokio::task::JoinHandle;
@@ -19,6 +21,7 @@ pub(crate) fn dispatch(
tx: &mpsc::Sender<AiTuiEvent>,
app_ctx: &AppContext,
client_ctx: &ClientContext,
session_mgr: &mut SessionManager,
) {
match event {
AiTuiEvent::ContinueAfterTools => {
@@ -28,7 +31,7 @@ pub(crate) fn dispatch(
on_input_updated(handle, input);
}
AiTuiEvent::SubmitInput(input) => {
on_submit_input(handle, tx, app_ctx, client_ctx, input);
on_submit_input(handle, tx, app_ctx, client_ctx, input, session_mgr);
}
AiTuiEvent::SlashCommand(cmd) => {
on_slash_command(handle, cmd);
@@ -61,6 +64,35 @@ pub(crate) fn dispatch(
on_exit(handle);
}
}
// Persist any new conversation events after each dispatch cycle.
persist_session(handle, session_mgr);
}
/// Persist new events and the server session ID if it has changed.
/// Called from the dispatch thread (sync), bridges to async via the tokio handle.
fn persist_session(handle: &Handle<Session>, session_mgr: &mut SessionManager) {
let Ok((events, server_sid)) = handle
.fetch(|state| {
(
state.conversation.events.clone(),
state.conversation.session_id.clone(),
)
})
.blocking_recv()
else {
return;
};
let rt = tokio::runtime::Handle::current();
if let Err(e) = rt.block_on(session_mgr.persist_events(&events)) {
tracing::warn!("failed to persist session events: {e}");
}
if let Some(ref sid) = server_sid
&& let Err(e) = rt.block_on(session_mgr.persist_server_session_id(sid))
{
tracing::warn!("failed to persist server session ID: {e}");
}
}
fn launch_stream(
@@ -78,9 +110,10 @@ fn launch_stream(
handle.update(move |state| {
(setup)(state);
state.start_streaming();
let messages = state.conversation.events_to_messages();
let messages =
ContextWindowBuilder::with_default_budget().build(&state.conversation.events);
let sid = state.conversation.session_id.clone();
let request = ChatRequest::new(messages, sid, &caps);
let request = ChatRequest::new(messages, sid, &caps, state.invocation_id.clone());
let task: JoinHandle<()> = tokio::spawn(async move {
run_chat_stream(h2, tx2, app, cc, request).await;
});
@@ -98,10 +131,30 @@ fn on_continue_after_tools(
}
fn on_input_updated(handle: &Handle<Session>, input: String) {
let input_blank = input.trim().is_empty();
let input_blank = input.is_empty();
let slash_command = if input.starts_with('/') {
Some(input.trim_start_matches('/').to_string())
} else {
None
};
handle.update(move |state| {
state.interaction.is_input_blank = input_blank;
state.interaction.slash_command_input = slash_command;
if let Some(query) = state.interaction.slash_command_input.as_ref() {
let mut results = state.slash_registry.search_fuzzy(query);
results.sort_by(|a, b| {
b.relevance
.partial_cmp(&a.relevance)
.unwrap_or(std::cmp::Ordering::Equal)
});
state.interaction.slash_command_search_results = results;
} else {
state.interaction.slash_command_search_results.clear();
}
});
}
@@ -111,7 +164,13 @@ fn on_submit_input(
app_ctx: &AppContext,
client_ctx: &ClientContext,
input: String,
session_mgr: &mut SessionManager,
) {
handle.update(move |state| {
state.interaction.slash_command_input = None;
state.interaction.slash_command_search_results.clear();
});
let input = input.trim().to_string();
if input.is_empty() {
let h2 = handle.clone();
@@ -129,9 +188,15 @@ fn on_submit_input(
}
if input.starts_with('/') {
handle.update(move |state| {
state.conversation.handle_slash_command(&input);
});
if input.trim() == "/new" {
on_new_session(handle, session_mgr);
} else {
handle.update(move |state| {
state
.conversation
.handle_slash_command(&input, &state.slash_registry);
});
}
return;
}
@@ -144,7 +209,9 @@ fn on_submit_input(
fn on_slash_command(handle: &Handle<Session>, command: String) {
handle.update(move |state| {
state.conversation.handle_slash_command(&command);
state
.conversation
.handle_slash_command(&command, &state.slash_registry);
});
}
@@ -533,6 +600,38 @@ fn on_retry(
});
}
fn on_new_session(handle: &Handle<Session>, session_mgr: &mut SessionManager) {
let rt = tokio::runtime::Handle::current();
if let Err(e) = rt.block_on(session_mgr.archive_and_reset()) {
tracing::warn!("failed to start new session: {e}");
return;
}
handle.update(|state| {
// Move the current invocation's visible events to the archived view
// so they remain on screen but are no longer sent to the API.
let visible_events: Vec<ConversationEvent> =
state.conversation.events[state.view_start_index..].to_vec();
state.archived_view_events.extend(visible_events);
state.conversation.events.clear();
state.conversation.session_id = None;
state.tool_tracker = crate::tools::ToolTracker::new();
state.view_start_index = 0;
state.is_resumed = false;
state.last_event_time = None;
state
.conversation
.events
.push(ConversationEvent::OutOfBandOutput {
name: "System".to_string(),
command: Some("/new".to_string()),
content: "Started a new session.".to_string(),
});
});
}
fn on_exit(handle: &Handle<Session>) {
let h2 = handle.clone();
handle.update(move |state| {
+2 -1
View File
@@ -1,7 +1,8 @@
pub(crate) mod components;
pub(crate) mod dispatch;
pub(crate) mod events;
pub(crate) mod slash;
pub(crate) mod state;
pub(crate) mod view;
pub(crate) use state::{ConversationEvent, Session};
pub(crate) use state::{ConversationEvent, Session, events_to_messages};
+79
View File
@@ -0,0 +1,79 @@
#[derive(Debug, Clone)]
pub(crate) struct SlashCommand {
pub name: String,
pub description: String,
}
impl SlashCommand {
pub fn new(name: &str, description: &str) -> Self {
Self {
name: name.to_string(),
description: description.to_string(),
}
}
}
#[derive(Debug)]
pub(crate) struct SlashCommandRegistry {
commands: Vec<SlashCommand>,
}
#[derive(Debug, Clone)]
pub(crate) struct SlashCommandSearchResult {
pub command: SlashCommand,
pub relevance: f32,
pub span: (usize, usize),
}
impl SlashCommandRegistry {
pub fn new() -> Self {
Self {
commands: Vec::new(),
}
}
pub fn register(&mut self, command: SlashCommand) {
self.commands.push(command);
}
pub fn get_commands(&self) -> &[SlashCommand] {
&self.commands
}
pub fn search_fuzzy(&self, query: &str) -> Vec<SlashCommandSearchResult> {
let query_lower = query.to_lowercase();
self.commands
.iter()
.filter_map(|command| {
let name_lower = command.name.to_lowercase();
if let Some(start) = name_lower.find(&query_lower as &str) {
let end = start + query_lower.len();
Some((command, start, end))
} else {
None
}
})
.map(|(command, start, end)| {
SlashCommandSearchResult {
command: command.clone(),
relevance: 1.0, // Simple relevance score for now
span: (start, end),
}
})
.collect()
}
}
impl Default for SlashCommandRegistry {
fn default() -> Self {
let mut registry = Self::new();
registry.register(SlashCommand::new("help", "Show help information"));
registry.register(SlashCommand::new(
"new",
"Start a new conversation, archiving the current one",
));
registry
}
}
+203 -134
View File
@@ -5,7 +5,10 @@
use tokio::task::AbortHandle;
use crate::tools::{ClientToolCall, ToolOutcome, ToolTracker};
use crate::{
tools::{ClientToolCall, ToolOutcome, ToolTracker},
tui::slash::{SlashCommandRegistry, SlashCommandSearchResult},
};
/// Streaming status indicators from server
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -57,9 +60,25 @@ pub(crate) enum ConversationEvent {
command: Option<String>,
content: String,
},
/// Context injected for the LLM that is not rendered in the TUI.
/// Converted to a user message in the API protocol.
SystemContext { content: String },
}
impl ConversationEvent {
/// Whether this event represents actual conversation content sent to the API.
/// Used to determine if a resumed session has meaningful context.
pub(crate) fn is_api_content(&self) -> bool {
match self {
ConversationEvent::UserMessage { .. } => true,
ConversationEvent::Text { .. } => true,
ConversationEvent::ToolCall { .. } => true,
ConversationEvent::ToolResult { .. } => true,
ConversationEvent::OutOfBandOutput { .. } => false,
ConversationEvent::SystemContext { .. } => false,
}
}
/// Extract command from a suggest_command tool call
pub(crate) fn as_command(&self) -> Option<&str> {
if let ConversationEvent::ToolCall { name, input, .. } = self
@@ -111,131 +130,6 @@ impl Conversation {
}
}
/// Convert conversation events to Claude API message format
pub fn events_to_messages(&self) -> Vec<serde_json::Value> {
let mut messages = Vec::new();
let mut i = 0;
let events = &self.events;
while i < events.len() {
match &events[i] {
ConversationEvent::UserMessage { content } => {
messages.push(serde_json::json!({
"role": "user",
"content": content
}));
i += 1;
}
ConversationEvent::Text { content } => {
// Check if the next event(s) are ToolCalls — if so, combine
// into a single assistant message with mixed content blocks.
let next_is_tool_call = events
.get(i + 1)
.is_some_and(|e| matches!(e, ConversationEvent::ToolCall { .. }));
if next_is_tool_call {
let mut content_blocks = Vec::new();
if !content.is_empty() {
content_blocks.push(serde_json::json!({
"type": "text",
"text": content
}));
}
while let Some(ConversationEvent::ToolCall {
id, name, input, ..
}) = events.get(i + 1)
{
content_blocks.push(serde_json::json!({
"type": "tool_use",
"id": id,
"name": name,
"input": input
}));
i += 1;
}
messages.push(serde_json::json!({
"role": "assistant",
"content": content_blocks
}));
i += 1;
} else {
messages.push(serde_json::json!({
"role": "assistant",
"content": content
}));
i += 1;
}
}
ConversationEvent::ToolCall { .. } => {
// ToolCalls without preceding Text (shouldn't normally happen,
// but handle defensively)
let mut tool_uses = Vec::new();
while i < events.len() {
if let ConversationEvent::ToolCall {
id, name, input, ..
} = &events[i]
{
tool_uses.push(serde_json::json!({
"type": "tool_use",
"id": id,
"name": name,
"input": input
}));
i += 1;
} else {
break;
}
}
messages.push(serde_json::json!({
"role": "assistant",
"content": tool_uses
}));
}
ConversationEvent::ToolResult {
tool_use_id,
content,
is_error,
remote,
content_length,
} => {
let tool_result = if *remote {
let mut obj = serde_json::json!({
"type": "tool_result",
"tool_use_id": tool_use_id,
"remote": true,
"is_error": is_error
});
if let Some(len) = content_length {
obj["content_length"] = serde_json::json!(len);
}
obj
} else {
serde_json::json!({
"type": "tool_result",
"tool_use_id": tool_use_id,
"content": content,
"is_error": is_error
})
};
messages.push(serde_json::json!({
"role": "user",
"content": [tool_result]
}));
i += 1;
}
ConversationEvent::OutOfBandOutput { .. } => {
// Out-of-band output is not sent to the server, so we don't need to add it to the messages
i += 1;
}
}
}
messages
}
/// Get the most recent command from events
pub fn current_command(&self) -> Option<&str> {
self.events.iter().rev().find_map(|e| e.as_command())
@@ -343,15 +237,22 @@ impl Conversation {
}
/// Handle a slash command
pub fn handle_slash_command(&mut self, command: &str) {
pub fn handle_slash_command(&mut self, command: &str, registry: &SlashCommandRegistry) {
match command.trim() {
"/help" => {
let content = include_str!("./content/help.md");
let commands = registry
.get_commands()
.iter()
.map(|cmd| format!("- `/{}` - {}", cmd.name, cmd.description))
.collect::<Vec<_>>()
.join("\n");
let content = include_str!("./content/help.md").replace("{commands}", &commands);
self.events.push(ConversationEvent::OutOfBandOutput {
name: "System".to_string(),
command: Some("/help".to_string()),
content: content.to_string(),
content,
});
}
_ => self.events.push(ConversationEvent::OutOfBandOutput {
@@ -363,6 +264,147 @@ impl Conversation {
}
}
/// Convert a slice of conversation events to Claude API message format.
///
/// This is the canonical event-to-message conversion, used by the context window
/// builder to convert turn slices independently. The logic handles combining
/// adjacent Text + ToolCall events into single assistant messages with mixed
/// content blocks.
pub(crate) fn events_to_messages(events: &[ConversationEvent]) -> Vec<serde_json::Value> {
let mut messages = Vec::new();
let mut i = 0;
while i < events.len() {
match &events[i] {
ConversationEvent::UserMessage { content } => {
messages.push(serde_json::json!({
"role": "user",
"content": content
}));
i += 1;
}
ConversationEvent::Text { content } if content.is_empty() => {
// Skip empty text events (e.g. streaming buffer before
// any data arrived).
i += 1;
}
ConversationEvent::Text { content } => {
// Check if the next event(s) are ToolCalls — if so, combine
// into a single assistant message with mixed content blocks.
let next_is_tool_call = events
.get(i + 1)
.is_some_and(|e| matches!(e, ConversationEvent::ToolCall { .. }));
if next_is_tool_call {
let mut content_blocks = Vec::new();
if !content.is_empty() {
content_blocks.push(serde_json::json!({
"type": "text",
"text": content
}));
}
while let Some(ConversationEvent::ToolCall {
id, name, input, ..
}) = events.get(i + 1)
{
content_blocks.push(serde_json::json!({
"type": "tool_use",
"id": id,
"name": name,
"input": input
}));
i += 1;
}
messages.push(serde_json::json!({
"role": "assistant",
"content": content_blocks
}));
i += 1;
} else {
messages.push(serde_json::json!({
"role": "assistant",
"content": content
}));
i += 1;
}
}
ConversationEvent::ToolCall { .. } => {
// ToolCalls without preceding Text (shouldn't normally happen,
// but handle defensively)
let mut tool_uses = Vec::new();
while i < events.len() {
if let ConversationEvent::ToolCall {
id, name, input, ..
} = &events[i]
{
tool_uses.push(serde_json::json!({
"type": "tool_use",
"id": id,
"name": name,
"input": input
}));
i += 1;
} else {
break;
}
}
messages.push(serde_json::json!({
"role": "assistant",
"content": tool_uses
}));
}
ConversationEvent::ToolResult {
tool_use_id,
content,
is_error,
remote,
content_length,
} => {
let tool_result = if *remote {
let mut obj = serde_json::json!({
"type": "tool_result",
"tool_use_id": tool_use_id,
"remote": true,
"is_error": is_error
});
if let Some(len) = content_length {
obj["content_length"] = serde_json::json!(len);
}
obj
} else {
serde_json::json!({
"type": "tool_result",
"tool_use_id": tool_use_id,
"content": content,
"is_error": is_error
})
};
messages.push(serde_json::json!({
"role": "user",
"content": [tool_result]
}));
i += 1;
}
ConversationEvent::OutOfBandOutput { .. } => {
// Out-of-band output is not sent to the server
i += 1;
}
ConversationEvent::SystemContext { content } => {
messages.push(serde_json::json!({
"role": "user",
"content": content
}));
i += 1;
}
}
}
messages
}
/// Ephemeral UI/presentation state
#[derive(Debug)]
pub(crate) struct Interaction {
@@ -370,6 +412,10 @@ pub(crate) struct Interaction {
pub mode: AppMode,
/// Whether the input is blank
pub is_input_blank: bool,
/// The currently in-progress slash command (if any)
pub slash_command_input: Option<String>,
/// Search results for the current slash command input
pub slash_command_search_results: Vec<SlashCommandSearchResult>,
/// True when user has pressed Enter once on a dangerous command
pub confirmation_pending: bool,
/// Current streaming status
@@ -385,6 +431,8 @@ impl Interaction {
Self {
mode: AppMode::Input,
is_input_blank: false,
slash_command_input: None,
slash_command_search_results: Vec::new(),
confirmation_pending: false,
streaming_status: None,
was_interrupted: false,
@@ -410,10 +458,26 @@ pub(crate) struct Session {
pub exit_action: Option<ExitAction>,
/// Abort handle for the active streaming task, if any
pub stream_abort: Option<AbortHandle>,
/// Index into `conversation.events` where the current TUI invocation starts.
/// Events before this index are historical context sent to the API but not
/// rendered in the TUI.
pub view_start_index: usize,
/// Whether this session was resumed from a prior invocation.
pub is_resumed: bool,
/// Time of the last event from a previous invocation when resuming a session
pub last_event_time: Option<chrono::DateTime<chrono::Utc>>,
/// Events from archived sessions that are still rendered on screen but no
/// longer sent to the API. Accumulated by `/new` commands within a single
/// TUI lifetime.
pub archived_view_events: Vec<ConversationEvent>,
/// A registry of available slash commands
pub slash_registry: SlashCommandRegistry,
/// The unique ID for this invocation
pub invocation_id: String,
}
impl Session {
pub fn new(in_git_project: bool) -> Self {
pub fn new(in_git_project: bool, invocation_id: Option<String>) -> Self {
Self {
conversation: Conversation::new(),
interaction: Interaction::new(),
@@ -421,6 +485,12 @@ impl Session {
in_git_project,
exit_action: None,
stream_abort: None,
view_start_index: 0,
is_resumed: false,
last_event_time: None,
archived_view_events: Vec::new(),
slash_registry: Default::default(),
invocation_id: invocation_id.unwrap_or_else(|| uuid::Uuid::now_v7().to_string()),
}
}
@@ -455,11 +525,10 @@ impl Session {
// ===== Streaming lifecycle methods =====
/// Start streaming response.
/// Pushes an empty Text event that will be mutated in-place as chunks arrive.
/// The Text event for streamed content is created lazily by
/// `append_streaming_text` when the first chunk arrives, so we
/// don't leave an empty assistant turn in the conversation.
pub fn start_streaming(&mut self) {
self.conversation.events.push(ConversationEvent::Text {
content: String::new(),
});
self.interaction.streaming_status = None;
self.interaction.was_interrupted = false;
self.interaction.mode = AppMode::Streaming;
+35 -2
View File
@@ -8,6 +8,7 @@ use ratatui_core::style::{Color, Modifier, Style};
use crate::tools::{ClientToolCall, TrackedTool};
use crate::tui::components::select::SelectOption;
use crate::tui::components::session_continue::SessionContinue;
use crate::tui::events::{AiTuiEvent, PermissionResult};
use super::components::atuin_ai::AtuinAi;
@@ -29,7 +30,10 @@ mod turn;
pub(crate) fn ai_view(state: &Session) -> Elements {
let mut turn_builder = turn::TurnBuilder::new(&state.tool_tracker);
for event in &state.conversation.events {
for event in &state.archived_view_events {
turn_builder.add_event(event);
}
for event in &state.conversation.events[state.view_start_index..] {
turn_builder.add_event(event);
}
let turns = turn_builder.build();
@@ -46,6 +50,10 @@ pub(crate) fn ai_view(state: &Session) -> Elements {
pending_confirmation: state.interaction.confirmation_pending,
has_executing_preview: state.tool_tracker.has_executing_preview(),
) {
#(if state.is_resumed && (!state.is_exiting() || !turns.is_empty()) {
SessionContinue(key: "continuation-notice", continued_at: state.last_event_time)
})
#(for (index, turn) in turns.iter().enumerate() {
#(match turn {
turn::UiTurn::User { events } => {
@@ -70,6 +78,13 @@ pub(crate) fn ai_view(state: &Session) -> Elements {
fn input_view(state: &Session) -> Elements {
let asking_tool = state.tool_tracker.asking_for_permission();
let in_git_project = state.in_git_project;
let slash_results = state
.interaction
.slash_command_search_results
.iter()
.take(4)
.collect::<Vec<_>>();
let first_slash_result = slash_results.first().cloned();
element! {
#(if let Some(tc) = asking_tool {
@@ -84,6 +99,7 @@ fn input_view(state: &Session) -> Elements {
title_right: "Atuin AI",
footer: state.footer_text(),
active: state.interaction.mode == AppMode::Input && !state.interaction.confirmation_pending,
slash_suggestion: first_slash_result.cloned()
)
#(if state.interaction.is_input_blank && state.conversation.has_any_command() && state.interaction.mode == AppMode::Input {
@@ -93,6 +109,23 @@ fn input_view(state: &Session) -> Elements {
Text { Span(text: "[Enter] Execute suggested command [Tab] Insert Command", style: Style::default().fg(Color::Gray)) }
})
})
#(if !slash_results.is_empty() {
#(for (i, result) in slash_results.iter().enumerate() {
Text {
Span(text: format!("/{}", &result.command.name[..result.span.0]), style: Style::default().fg(Color::Blue))
Span(text: &result.command.name[result.span.0..result.span.1], style: Style::default().fg(Color::Blue).add_modifier(Modifier::UNDERLINED))
Span(text: format!("{}", &result.command.name[result.span.1..]), style: Style::default().fg(Color::Blue))
Span(text: " - ")
Span(text: &result.command.description)
#(if i == 0 {
Span(text: " [Tab] Insert", style: Style::default().fg(Color::Gray).add_modifier(Modifier::ITALIC).dim())
})
}
})
})
}
})
}
@@ -270,7 +303,7 @@ fn out_of_band_turn_view(events: &[turn::UiEvent]) -> Elements {
element! {
View {
Text {
Span(text: "System", style: Style::default().fg(Color::Blue).add_modifier(Modifier::BOLD))
Span(text: " System ", style: Style::default().fg(Color::Blue).add_modifier(Modifier::BOLD).add_modifier(Modifier::REVERSED))
}
#(for event in events {
#(match event {
+3
View File
@@ -170,6 +170,9 @@ impl<'a> TurnBuilder<'a> {
} => {
self.add_out_of_band_output(name, command.as_deref(), content);
}
ConversationEvent::SystemContext { .. } => {
// Not rendered in the TUI — only sent to the API
}
}
}
+9
View File
@@ -664,6 +664,12 @@ pub struct Ai {
/// Only necessary for custom AI endpoints.
pub api_token: Option<String>,
/// Path to the AI sessions database.
pub db_path: String,
/// The maximum time in minutes that an AI session can be automatically resumed.
pub session_continue_minutes: i64,
/// Deprecated: use opening.send_cwd instead. Kept for backwards compatibility.
#[serde(default)]
pub send_cwd: Option<bool>,
@@ -1467,6 +1473,7 @@ impl Settings {
let record_store_path = data_dir.join("records.db");
let kv_path = data_dir.join("kv.db");
let scripts_path = data_dir.join("scripts.db");
let ai_sessions_path = data_dir.join("ai_sessions.db");
let socket_path = atuin_common::utils::runtime_dir().join("atuin.sock");
let pidfile_path = data_dir.join("atuin-daemon.pid");
let logs_dir = atuin_common::utils::logs_dir();
@@ -1550,6 +1557,8 @@ impl Settings {
.set_default("search.frequency_score_multiplier", 1.0)?
.set_default("search.frecency_score_multiplier", 1.0)?
.set_default("meta.db_path", meta_path.to_str())?
.set_default("ai.db_path", ai_sessions_path.to_str())?
.set_default("ai.session_continue_minutes", 60)?
.set_default("ai.send_cwd", false)?
.set_default("ai.opening.send_cwd", false)?
.set_default("ai.opening.send_last_command", false)?
+13 -1
View File
@@ -8,6 +8,18 @@ Default: `false`
Whether or not the AI feature are enabled. When set to `false`, the question mark keybinding will output a message with instructions to run `atuin setup` to enable the feature.
### db_path
Default: `ai_sessions.db` in the Atuin data directory.
The path to the SQLite database where Atuin AI sessions are stored.
### session_continue_minutes
Default: `60` (minutes)
The amount of time after the last interaction with Atuin AI that a session is considered "recent" and can be automatically continued. If you interact with Atuin AI and then invoke it again within this time window, the second interaction will be part of the same session. If you wait longer than this time window, a new session will be started. You can always start a new session manually by using the `/new` slash command in the Atuin AI interface.
### endpoint
Default: `null`
@@ -22,7 +34,7 @@ The API token for the Atuin AI endpoint. Used for AI features like command gener
## Capabilities
Settings that control what capabilities are sent to the LLM. These are specified under `[ai.capabilities]`.
Settings that control what capabilities are sent to the LLM, which the LLM uses to understand what features the client has available. These are specified under `[ai.capabilities]`.
### enable_history_search