mirror of
https://github.com/atuinsh/atuin.git
synced 2026-06-01 18:47:18 +02:00
feat: Client-tool execution + permission system (#3370)
Adds client-side tool execution to Atuin AI, starting with `atuin_history`. The server can request tool calls, which are executed locally with a permission system, and results are sent back to continue the conversation.
This commit is contained in:
@@ -13,3 +13,5 @@ ui/backend/target
|
||||
ui/backend/gen
|
||||
|
||||
sqlite-server.db*
|
||||
|
||||
.atuin/permissions.*.toml
|
||||
|
||||
Generated
+97
-36
@@ -280,21 +280,33 @@ dependencies = [
|
||||
"eye_declare",
|
||||
"eyre",
|
||||
"futures",
|
||||
"glob-match",
|
||||
"pretty_assertions",
|
||||
"pulldown-cmark",
|
||||
"ratatui",
|
||||
"ratatui-core",
|
||||
"ratatui-widgets",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tempfile",
|
||||
"thiserror 2.0.18",
|
||||
"time",
|
||||
"tokio",
|
||||
"toml",
|
||||
"toml_edit",
|
||||
"tracing",
|
||||
"tracing-appender",
|
||||
"tracing-subscriber",
|
||||
"tree-sitter",
|
||||
"tree-sitter-bash",
|
||||
"tree-sitter-fish",
|
||||
"tui-textarea-2",
|
||||
"typed-builder 0.18.2",
|
||||
"unicode-width 0.2.2",
|
||||
"uuid",
|
||||
"vt100",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -954,7 +966,7 @@ dependencies = [
|
||||
"pathdiff",
|
||||
"serde_core",
|
||||
"toml",
|
||||
"winnow",
|
||||
"winnow 0.7.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1498,9 +1510,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "eye_declare"
|
||||
version = "0.3.0"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cac5c1a3b194e6674e9e44dbfb035f31c4df7a1ff6c8765181c50e8482bb393a"
|
||||
checksum = "f9abe8051754adccf30ac4a0d54ce083a645fee7d4fc6c78d9d9770821bad45d"
|
||||
dependencies = [
|
||||
"crossterm",
|
||||
"eye_declare_macros",
|
||||
@@ -1514,9 +1526,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "eye_declare_macros"
|
||||
version = "0.3.0"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "98595776b5e10c6ea519c09940fb7995b64da1e9a70cc94aa6c08b3bd404925a"
|
||||
checksum = "39251ef16365f347032ab2344ad806f64d59f29fb171b4bafd05595fbda2604d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -1848,6 +1860,12 @@ dependencies = [
|
||||
"wasip3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "glob-match"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9985c9503b412198aa4197559e9a318524ebc4519c229bfa05a535828c950b9d"
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.13"
|
||||
@@ -4302,6 +4320,7 @@ version = "1.0.149"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86"
|
||||
dependencies = [
|
||||
"indexmap 2.13.0",
|
||||
"itoa",
|
||||
"memchr",
|
||||
"serde",
|
||||
@@ -4332,9 +4351,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_spanned"
|
||||
version = "1.0.4"
|
||||
version = "1.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776"
|
||||
checksum = "6662b5879511e06e8999a8a235d848113e942c9124f211511b16466ee2995f26"
|
||||
dependencies = [
|
||||
"serde_core",
|
||||
]
|
||||
@@ -4804,6 +4823,12 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "streaming-iterator"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520"
|
||||
|
||||
[[package]]
|
||||
name = "stringprep"
|
||||
version = "0.1.5"
|
||||
@@ -5197,22 +5222,24 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "1.0.6+spec-1.1.0"
|
||||
version = "1.1.1+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "399b1124a3c9e16766831c6bba21e50192572cdd98706ea114f9502509686ffc"
|
||||
checksum = "994b95d9e7bae62b34bab0e2a4510b801fa466066a6a8b2b57361fa1eba068ee"
|
||||
dependencies = [
|
||||
"indexmap 2.13.0",
|
||||
"serde_core",
|
||||
"serde_spanned",
|
||||
"toml_datetime",
|
||||
"toml_parser",
|
||||
"winnow",
|
||||
"toml_writer",
|
||||
"winnow 1.0.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_datetime"
|
||||
version = "1.0.0+spec-1.1.0"
|
||||
version = "1.1.1+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e"
|
||||
checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7"
|
||||
dependencies = [
|
||||
"serde_core",
|
||||
]
|
||||
@@ -5227,23 +5254,23 @@ dependencies = [
|
||||
"toml_datetime",
|
||||
"toml_parser",
|
||||
"toml_writer",
|
||||
"winnow",
|
||||
"winnow 0.7.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_parser"
|
||||
version = "1.0.9+spec-1.1.0"
|
||||
version = "1.1.1+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4"
|
||||
checksum = "39ca317ebc49f06bd748bfba29533eac9485569dc9bf80b849024b025e814fb9"
|
||||
dependencies = [
|
||||
"winnow",
|
||||
"winnow 1.0.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_writer"
|
||||
version = "1.0.6+spec-1.1.0"
|
||||
version = "1.1.1+spec-1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607"
|
||||
checksum = "756daf9b1013ebe47a8776667b466417e2d4c5679d441c26230efd9ef78692db"
|
||||
|
||||
[[package]]
|
||||
name = "tonic"
|
||||
@@ -5473,6 +5500,46 @@ dependencies = [
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter"
|
||||
version = "0.26.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "887bd495d0582c5e3e0d8ece2233666169fa56a9644d172fc22ad179ab2d0538"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"regex",
|
||||
"regex-syntax",
|
||||
"serde_json",
|
||||
"streaming-iterator",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-bash"
|
||||
version = "0.25.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e5ec769279cc91b561d3df0d8a5deb26b0ad40d183127f409494d6d8fc53062"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-fish"
|
||||
version = "3.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "014e3b299f251e9c2e372e3b5e1b0323ef21196e9aa2e90a5bc1f6130cbe8b18"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-language"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "009994f150cc0cd50ff54917d5bc8bffe8cad10ca10d81c34da2ec421ae61782"
|
||||
|
||||
[[package]]
|
||||
name = "tree_magic_mini"
|
||||
version = "3.2.2"
|
||||
@@ -5709,35 +5776,23 @@ checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
||||
|
||||
[[package]]
|
||||
name = "vt100"
|
||||
version = "0.15.2"
|
||||
version = "0.16.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "84cd863bf0db7e392ba3bd04994be3473491b31e66340672af5d11943c6274de"
|
||||
checksum = "054ff75fb8fa83e609e685106df4faeffdf3a735d3c74ebce97ec557d5d36fd9"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"log",
|
||||
"unicode-width 0.1.14",
|
||||
"unicode-width 0.2.2",
|
||||
"vte",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "vte"
|
||||
version = "0.11.1"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f5022b5fbf9407086c180e9557be968742d839e68346af7792b8592489732197"
|
||||
checksum = "a5924018406ce0063cd67f8e008104968b74b563ee1b85dde3ed1f7cb87d3dbd"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"utf8parse",
|
||||
"vte_generate_state_changes",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "vte_generate_state_changes"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e369bee1b05d510a7b4ed645f5faa90619e05437111783ea5848f28d97d3c2e"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6549,6 +6604,12 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winnow"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09dac053f1cd375980747450bfc7250c264eaae0583872e845c0c7cd578872b5"
|
||||
|
||||
[[package]]
|
||||
name = "winreg"
|
||||
version = "0.10.1"
|
||||
|
||||
+9
-1
@@ -1,5 +1,9 @@
|
||||
[workspace]
|
||||
members = ["crates/*", "crates/atuin-nucleo/matcher", "crates/atuin-nucleo/bench"]
|
||||
members = [
|
||||
"crates/*",
|
||||
"crates/atuin-nucleo/matcher",
|
||||
"crates/atuin-nucleo/bench",
|
||||
]
|
||||
|
||||
resolver = "2"
|
||||
exclude = ["ui/backend", "crates/atuin-nucleo/matcher/fuzz"]
|
||||
@@ -65,6 +69,10 @@ rustls = { version = "0.23", default-features = false, features = [
|
||||
"std",
|
||||
"tls12",
|
||||
] }
|
||||
glob-match = "0.2.1"
|
||||
vt100 = "0.16"
|
||||
regex = "1.10.5"
|
||||
toml_edit = "0.25.4"
|
||||
|
||||
[workspace.dependencies.tracing-subscriber]
|
||||
version = "0.3"
|
||||
|
||||
@@ -12,6 +12,10 @@ repository = { workspace = true }
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[features]
|
||||
default = []
|
||||
tree-sitter = ["dep:tree-sitter-lib", "dep:tree-sitter-bash", "dep:tree-sitter-fish"]
|
||||
|
||||
[dependencies]
|
||||
atuin-client = { workspace = true }
|
||||
atuin-common = { workspace = true }
|
||||
@@ -39,9 +43,21 @@ async-stream = "0.3"
|
||||
uuid = { workspace = true }
|
||||
tui-textarea-2 = "0.10.2"
|
||||
unicode-width = "0.2"
|
||||
eye_declare = "0.3"
|
||||
eye_declare = "0.4"
|
||||
ratatui-core = "0.1"
|
||||
ratatui-widgets = "0.3"
|
||||
thiserror = { workspace = true }
|
||||
glob-match = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
time = { workspace = true }
|
||||
toml = "1.1"
|
||||
toml_edit = { workspace = true }
|
||||
tree-sitter-lib = { package = "tree-sitter", version = "0.26.8", optional = true }
|
||||
tree-sitter-bash = { version = "0.25.1", optional = true }
|
||||
tree-sitter-fish = { version = "3.6.0", optional = true }
|
||||
typed-builder = { workspace = true }
|
||||
vt100 = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
|
||||
@@ -9,7 +9,7 @@ use eyre::Result;
|
||||
use tracing_appender::rolling::{RollingFileAppender, Rotation};
|
||||
use tracing_subscriber::{EnvFilter, Layer, fmt, layer::SubscriberExt, util::SubscriberInitExt};
|
||||
pub mod init;
|
||||
pub mod inline;
|
||||
pub(crate) mod inline;
|
||||
|
||||
#[derive(Args, Debug)]
|
||||
pub struct AiArgs {
|
||||
@@ -71,7 +71,7 @@ pub async fn run(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn detect_shell() -> Option<String> {
|
||||
pub(crate) fn detect_shell() -> Option<String> {
|
||||
Some(Shell::current().to_string())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::commands::detect_shell;
|
||||
|
||||
pub async fn run(shell: String) -> eyre::Result<()> {
|
||||
pub(crate) async fn run(shell: String) -> eyre::Result<()> {
|
||||
let integration = match shell.as_str() {
|
||||
"zsh" => generate_zsh_integration(),
|
||||
"bash" => generate_bash_integration(),
|
||||
|
||||
@@ -1,38 +1,44 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::mpsc;
|
||||
|
||||
use crate::commands::detect_shell;
|
||||
use crate::context::{AppContext, ClientContext};
|
||||
use crate::tui::dispatch;
|
||||
use crate::tui::events::AiTuiEvent;
|
||||
use crate::tui::state::{AppState, ExitAction};
|
||||
use crate::tui::state::{ExitAction, Session};
|
||||
use crate::tui::view::ai_view;
|
||||
use atuin_client::database::{Database, Sqlite};
|
||||
use atuin_client::distro::detect_linux_distribution;
|
||||
use atuin_common::tls::ensure_crypto_provider;
|
||||
use eventsource_stream::Eventsource;
|
||||
use eye_declare::{Application, CtrlCBehavior, Handle};
|
||||
use eye_declare::{Application, CtrlCBehavior};
|
||||
use eyre::{Context as _, Result, bail};
|
||||
use futures::StreamExt;
|
||||
use reqwest::Url;
|
||||
use tracing::{debug, error, info, trace};
|
||||
use tracing::{debug, info};
|
||||
|
||||
pub async fn run(
|
||||
pub(crate) async fn run(
|
||||
initial_command: Option<String>,
|
||||
api_endpoint: Option<String>,
|
||||
api_token: Option<String>,
|
||||
settings: &atuin_client::settings::Settings,
|
||||
output_for_hook: bool,
|
||||
) -> Result<()> {
|
||||
if !settings.ai.enabled.unwrap_or(false) {
|
||||
emit_shell_result(
|
||||
Action::Print(
|
||||
"Atuin AI is not enabled. Please enable it in your settings or run `atuin setup`."
|
||||
.to_string(),
|
||||
),
|
||||
output_for_hook,
|
||||
);
|
||||
if settings.ai.enabled == Some(false) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if settings.ai.enabled.is_none() {
|
||||
match prompt_ai_setup()? {
|
||||
SetupChoice::EnableAi => {
|
||||
set_ai_enabled(true).await?;
|
||||
}
|
||||
SetupChoice::DisableKeybind => {
|
||||
set_ai_enabled(false).await?;
|
||||
emit_shell_result(Action::Cancel, output_for_hook);
|
||||
return Ok(());
|
||||
}
|
||||
SetupChoice::Cancel => {
|
||||
emit_shell_result(Action::Cancel, output_for_hook);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let endpoint = api_endpoint.as_deref().unwrap_or(
|
||||
settings
|
||||
.ai
|
||||
@@ -48,7 +54,36 @@ pub async fn run(
|
||||
ensure_hub_session(settings).await?
|
||||
};
|
||||
|
||||
let action = run_inline_tui(endpoint.to_string(), token, initial_command, settings).await?;
|
||||
let history_db_path = PathBuf::from(settings.db_path.as_str());
|
||||
let history_db = Sqlite::new(history_db_path, settings.local_timeout)
|
||||
.await
|
||||
.context("failed to open history database for AI")?;
|
||||
|
||||
// Support both legacy [ai] send_cwd and new [ai.opening] send_cwd
|
||||
let send_cwd =
|
||||
settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false);
|
||||
|
||||
let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) {
|
||||
history_db.last().await.ok().flatten().map(|h| h.command)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let git_root = std::env::current_dir()
|
||||
.ok()
|
||||
.and_then(|cwd| atuin_common::utils::in_git_repo(cwd.to_str()?));
|
||||
|
||||
let ctx = AppContext {
|
||||
endpoint: endpoint.to_string(),
|
||||
token,
|
||||
send_cwd,
|
||||
last_command,
|
||||
history_db: std::sync::Arc::new(history_db),
|
||||
git_root,
|
||||
capabilities: settings.ai.capabilities.clone(),
|
||||
};
|
||||
|
||||
let action = run_inline_tui(ctx, initial_command).await?;
|
||||
emit_shell_result(action, output_for_hook);
|
||||
|
||||
Ok(())
|
||||
@@ -69,7 +104,7 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu
|
||||
if will_sync {
|
||||
println!(
|
||||
"Once logged in, your shell history will be synchronized via Atuin Hub if auto_sync is enabled or when manually syncing."
|
||||
)
|
||||
);
|
||||
}
|
||||
println!(
|
||||
"If you have an existing Atuin sync account, you can log in with your existing credentials."
|
||||
@@ -110,280 +145,17 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
// SSE streaming
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum ChatStreamEvent {
|
||||
TextChunk(String),
|
||||
ToolCall {
|
||||
id: String,
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: String,
|
||||
is_error: bool,
|
||||
},
|
||||
Status(String),
|
||||
Done {
|
||||
session_id: String,
|
||||
},
|
||||
Error(String),
|
||||
}
|
||||
|
||||
fn create_chat_stream(
|
||||
hub_address: String,
|
||||
token: String,
|
||||
session_id: Option<String>,
|
||||
messages: Vec<serde_json::Value>,
|
||||
send_cwd: bool,
|
||||
last_command: Option<String>,
|
||||
) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<ChatStreamEvent>> + Send>> {
|
||||
Box::pin(async_stream::stream! {
|
||||
ensure_crypto_provider();
|
||||
let endpoint = match hub_url(&hub_address, "/api/cli/chat") {
|
||||
Ok(url) => url,
|
||||
Err(e) => {
|
||||
yield Err(e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
debug!("Sending SSE request to {endpoint}");
|
||||
|
||||
let os = detect_os();
|
||||
let shell = detect_shell();
|
||||
|
||||
let mut context = serde_json::json!({
|
||||
"os": os,
|
||||
"shell": shell,
|
||||
"pwd": if send_cwd { std::env::current_dir()
|
||||
.ok()
|
||||
.map(|path| path.to_string_lossy().into_owned()) } else { None },
|
||||
"last_command": last_command,
|
||||
});
|
||||
|
||||
if os == "linux" {
|
||||
context["distro"] = serde_json::json!(detect_linux_distribution());
|
||||
}
|
||||
|
||||
let mut request_body = serde_json::json!({
|
||||
"messages": messages,
|
||||
"context": context,
|
||||
});
|
||||
|
||||
if let Some(ref sid) = session_id {
|
||||
trace!("Including session_id in request: {sid}");
|
||||
request_body["session_id"] = serde_json::json!(sid);
|
||||
}
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let response = match client
|
||||
.post(endpoint.clone())
|
||||
.header("Accept", "text/event-stream")
|
||||
.bearer_auth(&token)
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
yield Err(eyre::eyre!("Failed to send SSE request: {}", e));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let status = response.status();
|
||||
if status == reqwest::StatusCode::UNAUTHORIZED {
|
||||
error!("SSE request failed with status: {status}, clearing session");
|
||||
let _ = atuin_client::hub::delete_session().await;
|
||||
yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again."));
|
||||
return;
|
||||
}
|
||||
if !status.is_success() {
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
error!("SSE request failed ({}): {}", status, body);
|
||||
yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body));
|
||||
return;
|
||||
}
|
||||
|
||||
let byte_stream = response.bytes_stream();
|
||||
let mut stream = byte_stream.eventsource();
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
match event {
|
||||
Ok(sse_event) => {
|
||||
let event_type = sse_event.event.as_str();
|
||||
let data = sse_event.data.clone();
|
||||
|
||||
debug!(event_type = %event_type, "SSE event received");
|
||||
|
||||
match event_type {
|
||||
"text" => {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data)
|
||||
&& let Some(content) = json.get("content").and_then(|v| v.as_str())
|
||||
{
|
||||
yield Ok(ChatStreamEvent::TextChunk(content.to_string()));
|
||||
}
|
||||
}
|
||||
"tool_call" => {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||
let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let input = json.get("input").cloned().unwrap_or(serde_json::json!({}));
|
||||
yield Ok(ChatStreamEvent::ToolCall { id, name, input });
|
||||
}
|
||||
}
|
||||
"tool_result" => {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||
let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false);
|
||||
yield Ok(ChatStreamEvent::ToolResult { tool_use_id, content, is_error });
|
||||
}
|
||||
}
|
||||
"status" => {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data)
|
||||
&& let Some(state) = json.get("state").and_then(|v| v.as_str())
|
||||
{
|
||||
yield Ok(ChatStreamEvent::Status(state.to_string()));
|
||||
}
|
||||
}
|
||||
"done" => {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||
let session_id = json.get("session_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
yield Ok(ChatStreamEvent::Done { session_id });
|
||||
} else {
|
||||
yield Ok(ChatStreamEvent::Done { session_id: String::new() });
|
||||
}
|
||||
break;
|
||||
}
|
||||
"error" => {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||
let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string();
|
||||
error!("SSE error: {}", message);
|
||||
yield Ok(ChatStreamEvent::Error(message));
|
||||
} else {
|
||||
error!("SSE error: {}", data);
|
||||
yield Ok(ChatStreamEvent::Error(data));
|
||||
}
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
yield Err(eyre::eyre!("SSE error: {}", e));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
// Async streaming task — pushes updates to app state via Handle
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn run_chat_stream(
|
||||
handle: Handle<AppState>,
|
||||
endpoint: String,
|
||||
token: String,
|
||||
session_id: Option<String>,
|
||||
messages: Vec<serde_json::Value>,
|
||||
send_cwd: bool,
|
||||
last_command: Option<String>,
|
||||
) {
|
||||
let stream = create_chat_stream(
|
||||
endpoint,
|
||||
token,
|
||||
session_id,
|
||||
messages,
|
||||
send_cwd,
|
||||
last_command,
|
||||
);
|
||||
futures::pin_mut!(stream);
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
match event {
|
||||
Ok(ChatStreamEvent::TextChunk(text)) => {
|
||||
trace!(text = %text, "Processing TextChunk");
|
||||
handle.update(move |state| {
|
||||
state.append_streaming_text(&text);
|
||||
});
|
||||
}
|
||||
Ok(ChatStreamEvent::ToolCall { id, name, input }) => {
|
||||
trace!(id = %id, name = %name, "Processing ToolCall");
|
||||
handle.update(move |state| {
|
||||
state.add_tool_call(id, name, input);
|
||||
});
|
||||
}
|
||||
Ok(ChatStreamEvent::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
}) => {
|
||||
trace!(tool_use_id = %tool_use_id, "Processing ToolResult");
|
||||
handle.update(move |state| {
|
||||
state.add_tool_result(tool_use_id, content, is_error);
|
||||
});
|
||||
}
|
||||
Ok(ChatStreamEvent::Status(status)) => {
|
||||
trace!(status = %status, "Processing Status");
|
||||
handle.update(move |state| {
|
||||
state.update_streaming_status(&status);
|
||||
});
|
||||
}
|
||||
Ok(ChatStreamEvent::Done { session_id }) => {
|
||||
trace!(session_id = %session_id, "Processing Done");
|
||||
handle.update(move |state| {
|
||||
if !session_id.is_empty() {
|
||||
state.store_session_id(session_id);
|
||||
}
|
||||
state.finalize_streaming();
|
||||
});
|
||||
break;
|
||||
}
|
||||
Ok(ChatStreamEvent::Error(msg)) => {
|
||||
trace!(error = %msg, "Processing Error");
|
||||
handle.update(move |state| {
|
||||
state.streaming_error(msg);
|
||||
});
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = e.to_string();
|
||||
handle.update(move |state| {
|
||||
state.streaming_error(msg);
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
// Main TUI entry point
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn run_inline_tui(
|
||||
endpoint: String,
|
||||
token: String,
|
||||
initial_prompt: Option<String>,
|
||||
settings: &atuin_client::settings::Settings,
|
||||
) -> Result<Action> {
|
||||
let initial_state = AppState::new();
|
||||
|
||||
println!();
|
||||
async fn run_inline_tui(ctx: AppContext, initial_prompt: Option<String>) -> Result<Action> {
|
||||
let client_ctx = ClientContext::detect();
|
||||
|
||||
let (tx, rx) = mpsc::channel::<AiTuiEvent>();
|
||||
|
||||
let initial_state = Session::new(ctx.git_root.is_some());
|
||||
|
||||
println!();
|
||||
|
||||
// If there's an initial prompt, send it as a SubmitInput event
|
||||
// so it flows through the same path as user-typed input.
|
||||
if let Some(prompt) = initial_prompt {
|
||||
@@ -396,164 +168,17 @@ async fn run_inline_tui(
|
||||
.ctrl_c(CtrlCBehavior::Deliver)
|
||||
.keyboard_protocol(eye_declare::KeyboardProtocol::Enhanced)
|
||||
.bracketed_paste(true)
|
||||
.with_context(tx)
|
||||
.with_context(tx.clone())
|
||||
.extra_newlines_at_exit(1)
|
||||
.build()?;
|
||||
|
||||
// Support both legacy [ai] send_cwd and new [ai.opening] send_cwd
|
||||
let send_cwd =
|
||||
settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false);
|
||||
|
||||
let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) {
|
||||
let db_path = PathBuf::from(settings.db_path.as_str());
|
||||
match Sqlite::new(db_path, settings.local_timeout).await {
|
||||
Ok(db) => db.last().await.ok().flatten().map(|h| h.command),
|
||||
Err(e) => {
|
||||
debug!("Failed to open history database for read_history: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Event loop: receives AiTuiEvent from components, mutates state via Handle.
|
||||
let h = handle.clone();
|
||||
let ep = endpoint.clone();
|
||||
let tk = token.clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let tx = tx.clone();
|
||||
let client_ctx = client_ctx;
|
||||
while let Ok(event) = rx.recv() {
|
||||
match event {
|
||||
AiTuiEvent::InputUpdated(input) => {
|
||||
let input_blank = input.trim().is_empty();
|
||||
|
||||
h.update(move |state| {
|
||||
state.is_input_blank = input_blank;
|
||||
});
|
||||
}
|
||||
AiTuiEvent::SubmitInput(input) => {
|
||||
let input = input.trim().to_string();
|
||||
if input.is_empty() {
|
||||
let h2 = h.clone();
|
||||
h.update(move |state| {
|
||||
if state.has_any_command() {
|
||||
state.exit_action = Some(ExitAction::Execute(
|
||||
state.current_command().unwrap().to_string(),
|
||||
));
|
||||
} else {
|
||||
state.exit_action = Some(ExitAction::Cancel);
|
||||
}
|
||||
h2.exit();
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if input.starts_with('/') {
|
||||
let input_clone = input.clone();
|
||||
h.update(move |state| {
|
||||
state.handle_slash_command(&input_clone);
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
// Start generation and spawn streaming task
|
||||
let ep = ep.clone();
|
||||
let tk = tk.clone();
|
||||
let h2 = h.clone();
|
||||
let lc = last_command.clone();
|
||||
h.update(move |state| {
|
||||
state.start_generating(input);
|
||||
state.start_streaming();
|
||||
state.is_input_blank = true;
|
||||
let messages = state.events_to_messages();
|
||||
let sid = state.session_id.clone();
|
||||
let task = tokio::spawn(async move {
|
||||
run_chat_stream(h2, ep, tk, sid, messages, send_cwd, lc).await;
|
||||
});
|
||||
state.stream_abort = Some(task.abort_handle());
|
||||
});
|
||||
}
|
||||
|
||||
AiTuiEvent::SlashCommand(command) => {
|
||||
h.update(move |state| {
|
||||
state.handle_slash_command(&command);
|
||||
});
|
||||
}
|
||||
|
||||
AiTuiEvent::CancelGeneration => {
|
||||
h.update(|state| match state.mode {
|
||||
crate::tui::state::AppMode::Generating => {
|
||||
state.cancel_generation();
|
||||
}
|
||||
crate::tui::state::AppMode::Streaming => {
|
||||
state.cancel_streaming();
|
||||
}
|
||||
_ => {}
|
||||
});
|
||||
}
|
||||
|
||||
AiTuiEvent::ExecuteCommand => {
|
||||
let h2 = h.clone();
|
||||
h.update(move |state| {
|
||||
let cmd = state.current_command().map(|c| c.to_string());
|
||||
if let Some(cmd) = cmd {
|
||||
if state.is_current_command_dangerous() && !state.confirmation_pending {
|
||||
state.confirmation_pending = true;
|
||||
} else {
|
||||
state.confirmation_pending = false;
|
||||
state.exit_action = Some(ExitAction::Execute(cmd));
|
||||
h2.exit();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
AiTuiEvent::CancelConfirmation => {
|
||||
h.update(move |state| {
|
||||
state.confirmation_pending = false;
|
||||
});
|
||||
}
|
||||
|
||||
AiTuiEvent::InsertCommand => {
|
||||
let h2 = h.clone();
|
||||
h.update(move |state| {
|
||||
let cmd = state.current_command().map(|c| c.to_string());
|
||||
if let Some(cmd) = cmd {
|
||||
state.confirmation_pending = false;
|
||||
state.exit_action = Some(ExitAction::Insert(cmd));
|
||||
h2.exit();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
AiTuiEvent::Retry => {
|
||||
let ep = ep.clone();
|
||||
let tk = tk.clone();
|
||||
let h2 = h.clone();
|
||||
let lc = last_command.clone();
|
||||
h.update(move |state| {
|
||||
state.retry();
|
||||
state.start_streaming();
|
||||
let messages = state.events_to_messages();
|
||||
let sid = state.session_id.clone();
|
||||
let task = tokio::spawn(async move {
|
||||
run_chat_stream(h2, ep, tk, sid, messages, send_cwd, lc).await;
|
||||
});
|
||||
state.stream_abort = Some(task.abort_handle());
|
||||
});
|
||||
}
|
||||
|
||||
AiTuiEvent::Exit => {
|
||||
let h2 = h.clone();
|
||||
h.update(move |state| {
|
||||
if let Some(abort) = state.stream_abort.take() {
|
||||
abort.abort();
|
||||
}
|
||||
state.exit_action = Some(ExitAction::Cancel);
|
||||
h2.exit();
|
||||
});
|
||||
}
|
||||
}
|
||||
dispatch::dispatch(&h, event, &tx, &ctx, &client_ctx);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -573,51 +198,125 @@ async fn run_inline_tui(
|
||||
// Helpers
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
|
||||
fn hub_url(base: &str, path: &str) -> Result<Url> {
|
||||
let base_with_slash = if base.ends_with('/') {
|
||||
base.to_string()
|
||||
} else {
|
||||
format!("{base}/")
|
||||
};
|
||||
let stripped = path.strip_prefix('/').unwrap_or(path);
|
||||
Url::parse(&base_with_slash)?
|
||||
.join(stripped)
|
||||
.context("failed to build hub URL")
|
||||
}
|
||||
|
||||
fn detect_os() -> String {
|
||||
match std::env::consts::OS {
|
||||
"macos" => "macos".to_string(),
|
||||
"linux" => "linux".to_string(),
|
||||
"windows" => "windows".to_string(),
|
||||
other => format!("Other: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum Action {
|
||||
Execute(String),
|
||||
Insert(String),
|
||||
Print(String),
|
||||
enum SetupChoice {
|
||||
EnableAi,
|
||||
DisableKeybind,
|
||||
Cancel,
|
||||
}
|
||||
|
||||
fn emit_shell_result(action: Action, output_for_hook: bool) {
|
||||
if output_for_hook {
|
||||
match action {
|
||||
Action::Execute(output) => eprintln!("__atuin_ai_execute__:{output}"),
|
||||
Action::Insert(output) => eprintln!("__atuin_ai_insert__:{output}"),
|
||||
Action::Print(output) => eprintln!("__atuin_ai_print__:{output}"),
|
||||
Action::Cancel => eprintln!("__atuin_ai_cancel__"),
|
||||
}
|
||||
} else {
|
||||
match action {
|
||||
Action::Execute(output) => eprintln!("{output}"),
|
||||
Action::Insert(output) => eprintln!("{output}"),
|
||||
Action::Print(output) => eprintln!("{output}"),
|
||||
Action::Cancel => eprintln!(),
|
||||
fn prompt_ai_setup() -> Result<SetupChoice> {
|
||||
use crossterm::{
|
||||
cursor,
|
||||
event::{self, Event, KeyCode},
|
||||
terminal,
|
||||
};
|
||||
|
||||
let options = ["Enable Atuin AI", "Disable ? Keybind", "Cancel"];
|
||||
let mut selected: usize = 0;
|
||||
let mut stdout = std::io::stdout();
|
||||
|
||||
// Print header before raw mode so newlines render correctly.
|
||||
// Use stdout because the shell hook swaps stdout/stderr — stdout goes
|
||||
// to the terminal in both hook and non-hook modes.
|
||||
println!();
|
||||
println!(" Atuin AI is not yet configured.");
|
||||
println!();
|
||||
|
||||
terminal::enable_raw_mode().context("failed to enable raw mode")?;
|
||||
struct Guard;
|
||||
impl Drop for Guard {
|
||||
fn drop(&mut self) {
|
||||
let _ = terminal::disable_raw_mode();
|
||||
}
|
||||
}
|
||||
let _guard = Guard;
|
||||
|
||||
crossterm::execute!(stdout, cursor::Hide)?;
|
||||
|
||||
loop {
|
||||
render_setup_options(&mut stdout, &options, selected)?;
|
||||
|
||||
let ev = event::read().context("failed to read key event")?;
|
||||
|
||||
crossterm::execute!(stdout, cursor::MoveUp(options.len() as u16))?;
|
||||
|
||||
if let Event::Key(key) = ev {
|
||||
match key.code {
|
||||
KeyCode::Up | KeyCode::Char('k') => {
|
||||
selected = selected.saturating_sub(1);
|
||||
}
|
||||
KeyCode::Down | KeyCode::Char('j') => {
|
||||
if selected < options.len() - 1 {
|
||||
selected += 1;
|
||||
}
|
||||
}
|
||||
KeyCode::Enter => break,
|
||||
KeyCode::Esc => {
|
||||
selected = 2;
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final render with selection visible
|
||||
render_setup_options(&mut stdout, &options, selected)?;
|
||||
crossterm::execute!(stdout, cursor::Show)?;
|
||||
|
||||
Ok(match selected {
|
||||
0 => SetupChoice::EnableAi,
|
||||
1 => SetupChoice::DisableKeybind,
|
||||
_ => SetupChoice::Cancel,
|
||||
})
|
||||
}
|
||||
|
||||
fn render_setup_options(
|
||||
w: &mut impl std::io::Write,
|
||||
options: &[&str],
|
||||
selected: usize,
|
||||
) -> Result<()> {
|
||||
use crossterm::{
|
||||
style::Stylize,
|
||||
terminal::{Clear, ClearType},
|
||||
};
|
||||
|
||||
for (i, option) in options.iter().enumerate() {
|
||||
if i == selected {
|
||||
write!(w, "\r {}", format!("> {option}").bold().cyan())?;
|
||||
} else {
|
||||
write!(w, "\r {option}")?;
|
||||
}
|
||||
crossterm::execute!(w, Clear(ClearType::UntilNewLine))?;
|
||||
write!(w, "\r\n")?;
|
||||
}
|
||||
w.flush()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn set_ai_enabled(enabled: bool) -> Result<()> {
|
||||
let config_file = atuin_client::settings::Settings::get_config_path()?;
|
||||
let config_str = tokio::fs::read_to_string(&config_file).await?;
|
||||
let mut doc = config_str.parse::<toml_edit::DocumentMut>()?;
|
||||
|
||||
if !doc.contains_key("ai") {
|
||||
doc["ai"] = toml_edit::table();
|
||||
}
|
||||
doc["ai"]["enabled"] = toml_edit::value(enabled);
|
||||
|
||||
tokio::fs::write(&config_file, doc.to_string()).await?;
|
||||
|
||||
if !enabled {
|
||||
println!(
|
||||
"Atuin AI keybind disabled. You can re-enable with `atuin config set ai.enabled true`.",
|
||||
);
|
||||
println!("Restart your shell for changes to take effect.");
|
||||
// Two printlns to ensure the message is visible above the shell prompt after program ends.
|
||||
println!();
|
||||
println!();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn wait_for_login_confirmation() -> Result<bool> {
|
||||
@@ -646,3 +345,27 @@ fn wait_for_login_confirmation() -> Result<bool> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum Action {
|
||||
Execute(String),
|
||||
Insert(String),
|
||||
Cancel,
|
||||
}
|
||||
|
||||
fn emit_shell_result(action: Action, output_for_hook: bool) {
|
||||
if output_for_hook {
|
||||
match action {
|
||||
Action::Execute(output) => eprintln!("__atuin_ai_execute__:{output}"),
|
||||
Action::Insert(output) => eprintln!("__atuin_ai_insert__:{output}"),
|
||||
Action::Cancel => eprintln!("__atuin_ai_cancel__"),
|
||||
}
|
||||
} else {
|
||||
match action {
|
||||
Action::Execute(output) | Action::Insert(output) => {
|
||||
println!("{output}");
|
||||
}
|
||||
Action::Cancel => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use atuin_client::distro::detect_linux_distribution;
|
||||
use atuin_client::settings::AiCapabilities;
|
||||
|
||||
/// Session-scoped context for the AI chat session.
|
||||
/// Holds the API configuration and client settings needed by the event loop and stream task.
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct AppContext {
|
||||
pub endpoint: String,
|
||||
pub token: String,
|
||||
pub send_cwd: bool,
|
||||
pub last_command: Option<String>,
|
||||
pub history_db: Arc<atuin_client::database::Sqlite>,
|
||||
/// Git root of the current working directory, if inside a git repo.
|
||||
/// Resolves through worktrees to the main repo root.
|
||||
pub git_root: Option<PathBuf>,
|
||||
pub capabilities: AiCapabilities,
|
||||
}
|
||||
|
||||
/// Machine identity — computed once per session.
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ClientContext {
|
||||
pub os: String,
|
||||
pub shell: Option<String>,
|
||||
pub distro: Option<String>,
|
||||
}
|
||||
|
||||
impl ClientContext {
|
||||
pub(crate) fn detect() -> Self {
|
||||
let os = detect_os();
|
||||
let shell = crate::commands::detect_shell();
|
||||
let distro = if os == "linux" {
|
||||
Some(detect_linux_distribution())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self { os, shell, distro }
|
||||
}
|
||||
|
||||
/// Serialize to the JSON format the API expects for the "context" field.
|
||||
/// The `pwd` field is always dynamic (current working directory), so it's
|
||||
/// computed fresh on each call if `send_cwd` is true.
|
||||
pub(crate) fn to_json(&self, send_cwd: bool, last_command: Option<&str>) -> serde_json::Value {
|
||||
let mut ctx = serde_json::json!({
|
||||
"os": self.os,
|
||||
"shell": self.shell,
|
||||
"pwd": if send_cwd {
|
||||
std::env::current_dir().ok().map(|p| p.to_string_lossy().into_owned())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
"last_command": last_command,
|
||||
});
|
||||
|
||||
if let Some(ref distro) = self.distro {
|
||||
ctx["distro"] = serde_json::json!(distro);
|
||||
}
|
||||
|
||||
ctx
|
||||
}
|
||||
}
|
||||
|
||||
/// Move the `detect_os` function here since it's about client identity.
|
||||
fn detect_os() -> String {
|
||||
match std::env::consts::OS {
|
||||
"macos" => "macos".to_string(),
|
||||
"linux" => "linux".to_string(),
|
||||
"windows" => "windows".to_string(),
|
||||
other => format!("Other: {other}"),
|
||||
}
|
||||
}
|
||||
@@ -1,2 +1,6 @@
|
||||
pub mod commands;
|
||||
pub mod tui;
|
||||
pub(crate) mod context;
|
||||
pub(crate) mod permissions;
|
||||
pub(crate) mod stream;
|
||||
pub(crate) mod tools;
|
||||
pub(crate) mod tui;
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
use eyre::Result;
|
||||
|
||||
use crate::{permissions::file::RuleFile, tools::PermissableToolCall};
|
||||
|
||||
pub(crate) struct PermissionRequest<'t> {
|
||||
call: &'t (dyn PermissableToolCall + Send + Sync),
|
||||
}
|
||||
|
||||
impl<'t> PermissionRequest<'t> {
|
||||
pub fn new(call: &'t (dyn PermissableToolCall + Send + Sync)) -> Self {
|
||||
Self { call }
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) enum PermissionResponse {
|
||||
Allowed,
|
||||
Denied,
|
||||
Ask,
|
||||
}
|
||||
|
||||
pub(crate) struct PermissionChecker {
|
||||
files: Vec<RuleFile>,
|
||||
}
|
||||
|
||||
impl PermissionChecker {
|
||||
pub fn new(files: Vec<RuleFile>) -> Self {
|
||||
Self { files }
|
||||
}
|
||||
|
||||
pub async fn check<'t>(
|
||||
&self,
|
||||
request: &'t PermissionRequest<'t>,
|
||||
) -> Result<PermissionResponse> {
|
||||
// Files are in order from deepest to shallowest, so we can stop at the first match.
|
||||
// Within a file, the priority is ask -> deny -> allow
|
||||
// The first rule type that matches is the one that applies, even if a later rule would contradict it.
|
||||
for file in &self.files {
|
||||
for rule in &file.content.permissions.ask {
|
||||
if request.call.matches_rule(rule) {
|
||||
tracing::debug!(
|
||||
"Permission 'ASK' by rule: {} in file: {}",
|
||||
rule,
|
||||
file.path.display()
|
||||
);
|
||||
return Ok(PermissionResponse::Ask);
|
||||
}
|
||||
}
|
||||
|
||||
for rule in &file.content.permissions.deny {
|
||||
if request.call.matches_rule(rule) {
|
||||
tracing::debug!(
|
||||
"Permission 'DENY' by rule: {} in file: {}",
|
||||
rule,
|
||||
file.path.display()
|
||||
);
|
||||
return Ok(PermissionResponse::Denied);
|
||||
}
|
||||
}
|
||||
|
||||
for rule in &file.content.permissions.allow {
|
||||
if request.call.matches_rule(rule) {
|
||||
tracing::debug!(
|
||||
"Permission 'ALLOW' by rule: {} in file: {}",
|
||||
rule,
|
||||
file.path.display()
|
||||
);
|
||||
return Ok(PermissionResponse::Allowed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(PermissionResponse::Ask)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::permissions::rule::Rule;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct RuleFile {
|
||||
pub path: PathBuf,
|
||||
pub content: RuleFileContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub(crate) struct RuleFileContent {
|
||||
pub permissions: RuleFilePermissions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub(crate) struct RuleFilePermissions {
|
||||
#[serde(default)]
|
||||
pub allow: Vec<Rule>,
|
||||
#[serde(default)]
|
||||
pub deny: Vec<Rule>,
|
||||
#[serde(default)]
|
||||
pub ask: Vec<Rule>,
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
pub(crate) mod check;
|
||||
pub(crate) mod file;
|
||||
pub(crate) mod resolver;
|
||||
pub(crate) mod rule;
|
||||
pub(crate) mod shell;
|
||||
pub(crate) mod walker;
|
||||
pub(crate) mod writer;
|
||||
@@ -0,0 +1,31 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use eyre::Result;
|
||||
|
||||
use crate::permissions::check::{PermissionChecker, PermissionRequest, PermissionResponse};
|
||||
use crate::permissions::walker::PermissionWalker;
|
||||
use crate::permissions::writer;
|
||||
use crate::tools::ClientToolCall;
|
||||
|
||||
/// Resolves permissions for client tool calls by walking the filesystem to find permission files,
|
||||
pub(crate) struct PermissionResolver {
|
||||
checker: PermissionChecker,
|
||||
}
|
||||
|
||||
impl PermissionResolver {
|
||||
/// Create a new resolver that walks from `working_dir` to root for project
|
||||
/// permissions, and also checks the global permissions file.
|
||||
pub async fn new(working_dir: PathBuf) -> Result<Self> {
|
||||
let global_file = writer::global_permissions_path();
|
||||
let mut walker = PermissionWalker::new(working_dir, Some(global_file));
|
||||
walker.walk().await?;
|
||||
let checker = PermissionChecker::new(walker.rules().to_owned());
|
||||
Ok(Self { checker })
|
||||
}
|
||||
|
||||
/// Check whether `tool` is allowed, denied, or needs user confirmation.
|
||||
pub async fn check(&self, tool: &ClientToolCall) -> Result<PermissionResponse> {
|
||||
let request = PermissionRequest::new(tool);
|
||||
self.checker.check(&request).await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
static RULE_RE: OnceLock<Regex> = OnceLock::new();
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum RuleError {
|
||||
#[error("invalid rule format: {0}")]
|
||||
InvalidRule(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct Rule {
|
||||
pub tool: String,
|
||||
pub scope: Option<String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Rule {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self.scope.as_ref() {
|
||||
Some(scope) => write!(f, "{}({})", self.tool, scope),
|
||||
None => write!(f, "{}", self.tool),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Rule {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&self.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Rule {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
Self::try_from(s.as_str()).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
impl TryFrom<&str> for Rule {
|
||||
type Error = RuleError;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
let value = value.trim();
|
||||
let re = RULE_RE.get_or_init(|| Regex::new(r"^(\w+)(?:\((.*)\))?$").unwrap());
|
||||
let caps = re
|
||||
.captures(value)
|
||||
.ok_or(RuleError::InvalidRule(value.to_string()))?;
|
||||
let tool = caps.get(1).unwrap().as_str().to_string();
|
||||
let scope = caps.get(2).map(|m| m.as_str().to_string());
|
||||
Ok(Rule { tool, scope })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_rule_try_from() {
|
||||
assert_eq!(
|
||||
Rule::try_from("Read").unwrap(),
|
||||
Rule {
|
||||
tool: "Read".to_string(),
|
||||
scope: None
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
Rule::try_from("Read(*)").unwrap(),
|
||||
Rule {
|
||||
tool: "Read".to_string(),
|
||||
scope: Some("*".to_string())
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
Rule::try_from("Write(*.md)").unwrap(),
|
||||
Rule {
|
||||
tool: "Write".to_string(),
|
||||
scope: Some("*.md".to_string())
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
Rule::try_from("Shell(git commit *)").unwrap(),
|
||||
Rule {
|
||||
tool: "Shell".to_string(),
|
||||
scope: Some("git commit *".to_string())
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
Rule::try_from("Shell(echo ())").unwrap(),
|
||||
Rule {
|
||||
tool: "Shell".to_string(),
|
||||
scope: Some("echo ()".to_string())
|
||||
}
|
||||
);
|
||||
assert!(Rule::try_from("Shell(git commit *").is_err());
|
||||
assert!(Rule::try_from("Shell(git commit *)!").is_err());
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,121 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use eyre::Result;
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use crate::permissions::file::{RuleFile, RuleFileContent};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FoundRuleFile {
|
||||
depth: usize,
|
||||
file: RuleFile,
|
||||
}
|
||||
|
||||
pub(crate) struct PermissionWalker {
|
||||
start: PathBuf,
|
||||
/// Direct path to the global permissions file (e.g. `~/.config/atuin/permissions.ai.toml`).
|
||||
global_permissions_file: Option<PathBuf>,
|
||||
rules: Vec<RuleFile>,
|
||||
}
|
||||
|
||||
impl PermissionWalker {
|
||||
pub fn new(start: PathBuf, global_permissions_file: Option<PathBuf>) -> Self {
|
||||
Self {
|
||||
start,
|
||||
global_permissions_file,
|
||||
rules: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rules(&self) -> &[RuleFile] {
|
||||
&self.rules
|
||||
}
|
||||
|
||||
/// Walks the filesystem starting from the start path and collecting permission files along the way.
|
||||
/// Walks to the root, then checks the global permissions file, if any.
|
||||
pub async fn walk(&mut self) -> Result<()> {
|
||||
let dirs_to_check: Vec<PathBuf> = self.start.ancestors().map(PathBuf::from).collect();
|
||||
let dir_count = dirs_to_check.len();
|
||||
|
||||
let mut set: JoinSet<Result<Option<FoundRuleFile>>> = JoinSet::new();
|
||||
|
||||
for (index, path) in dirs_to_check.into_iter().enumerate() {
|
||||
set.spawn(async move {
|
||||
match check_dir_for_permissions(&path).await {
|
||||
Ok(Some(rule_file)) => Ok(Some(FoundRuleFile {
|
||||
depth: index,
|
||||
file: rule_file,
|
||||
})),
|
||||
Ok(None) => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Check the global file separately (it's a direct file path, not a dir/.atuin/ pattern)
|
||||
if let Some(global_path) = self.global_permissions_file.clone() {
|
||||
let depth = dir_count; // sorts after all directory-walk entries
|
||||
set.spawn(async move {
|
||||
match load_permissions_file(&global_path).await {
|
||||
Ok(Some(rule_file)) => Ok(Some(FoundRuleFile {
|
||||
depth,
|
||||
file: rule_file,
|
||||
})),
|
||||
Ok(None) => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let capacity = dir_count + usize::from(self.global_permissions_file.is_some());
|
||||
let mut found = Vec::with_capacity(capacity);
|
||||
while let Some(result) = set.join_next().await {
|
||||
let result = result?; // JoinErrors result in failure to walk the filesystem
|
||||
|
||||
match result {
|
||||
Ok(Some(FoundRuleFile { depth, file })) => {
|
||||
found.push((depth, file));
|
||||
}
|
||||
Ok(None) => {
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"Error while walking filesystem for permissions check; skipping: {}",
|
||||
e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
// join_next() returns in order of completion, not order of spawn
|
||||
found.sort_by_key(|(depth, _)| *depth);
|
||||
self.rules = found.into_iter().map(|(_, file)| file).collect();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks a directory for `.atuin/permissions.ai.toml` and returns the RuleFile if found.
|
||||
async fn check_dir_for_permissions(path: &Path) -> Result<Option<RuleFile>> {
|
||||
let file_path = path.join(".atuin").join("permissions.ai.toml");
|
||||
load_permissions_file(&file_path).await
|
||||
}
|
||||
|
||||
/// Load a permissions file from an exact path. Returns None if the file doesn't exist.
|
||||
async fn load_permissions_file(file_path: &Path) -> Result<Option<RuleFile>> {
|
||||
if !tokio::fs::try_exists(file_path).await? {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let raw = tokio::fs::read_to_string(file_path).await?;
|
||||
let content: RuleFileContent = toml::from_str(&raw)?;
|
||||
|
||||
// Use the file's parent as the rule file path (for logging/debugging)
|
||||
let path = file_path
|
||||
.parent()
|
||||
.map(Path::to_path_buf)
|
||||
.unwrap_or_else(|| file_path.to_path_buf());
|
||||
|
||||
Ok(Some(RuleFile { path, content }))
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
use std::path::Path;
|
||||
|
||||
use eyre::Result;
|
||||
|
||||
use crate::permissions::rule::Rule;
|
||||
|
||||
/// Whether a rule should be added to the allow or deny list.
|
||||
#[allow(dead_code)]
|
||||
pub(crate) enum RuleDisposition {
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
|
||||
/// Write a permission rule to a `permissions.ai.toml` file.
|
||||
///
|
||||
/// If the file doesn't exist it is created (along with parent directories).
|
||||
/// If it does exist, `toml_edit` is used to append the rule while preserving
|
||||
/// existing formatting and comments.
|
||||
///
|
||||
/// **Not concurrent-safe.** The read-modify-write cycle is not atomic. In the
|
||||
/// current UI this is fine — the Select widget serializes permission decisions —
|
||||
/// but callers should not invoke this concurrently for the same file.
|
||||
pub(crate) async fn write_rule(
|
||||
file_path: &Path,
|
||||
rule: &Rule,
|
||||
disposition: RuleDisposition,
|
||||
) -> Result<()> {
|
||||
let content = if tokio::fs::try_exists(file_path).await.unwrap_or(false) {
|
||||
tokio::fs::read_to_string(file_path).await?
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let mut doc: toml_edit::DocumentMut = content.parse()?;
|
||||
|
||||
// Ensure [permissions] table exists
|
||||
if !doc.contains_key("permissions") {
|
||||
doc["permissions"] = toml_edit::Item::Table(toml_edit::Table::new());
|
||||
}
|
||||
|
||||
let key = match disposition {
|
||||
RuleDisposition::Allow => "allow",
|
||||
RuleDisposition::Deny => "deny",
|
||||
};
|
||||
|
||||
// Use as_table_like_mut so both standard and inline tables work.
|
||||
let permissions = doc["permissions"]
|
||||
.as_table_like_mut()
|
||||
.ok_or_else(|| eyre::eyre!("[permissions] is not a table"))?;
|
||||
|
||||
// Get or create the array
|
||||
if !permissions.contains_key(key) {
|
||||
permissions.insert(key, toml_edit::Item::Value(toml_edit::Array::new().into()));
|
||||
}
|
||||
|
||||
let array = permissions
|
||||
.get_mut(key)
|
||||
.and_then(|item| item.as_value_mut())
|
||||
.and_then(|v| v.as_array_mut())
|
||||
.ok_or_else(|| eyre::eyre!("permissions.{key} is not an array"))?;
|
||||
|
||||
// Don't add duplicates
|
||||
let rule_str = rule.to_string();
|
||||
let already_present = array.iter().any(|v| v.as_str() == Some(&rule_str));
|
||||
if !already_present {
|
||||
array.push(rule_str);
|
||||
}
|
||||
|
||||
// Write back, creating parent directories as needed
|
||||
if let Some(parent) = file_path.parent() {
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
tokio::fs::write(file_path, doc.to_string()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build the path to the project-level permissions file.
|
||||
/// `project_root` is typically a git root or the current working directory.
|
||||
pub(crate) fn project_permissions_path(project_root: &Path) -> std::path::PathBuf {
|
||||
project_root.join(".atuin").join("permissions.ai.toml")
|
||||
}
|
||||
|
||||
/// Build the path to the global permissions file (sibling of atuin config).
|
||||
pub(crate) fn global_permissions_path() -> std::path::PathBuf {
|
||||
atuin_common::utils::config_dir().join("permissions.ai.toml")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn creates_new_file_with_allow_rule() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let file = dir.path().join("permissions.ai.toml");
|
||||
let rule = Rule {
|
||||
tool: "AtuinHistory".to_string(),
|
||||
scope: None,
|
||||
};
|
||||
|
||||
write_rule(&file, &rule, RuleDisposition::Allow)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let content = tokio::fs::read_to_string(&file).await.unwrap();
|
||||
assert!(content.contains("[permissions]"));
|
||||
assert!(content.contains(r#""AtuinHistory""#));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn appends_to_existing_file() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let file = dir.path().join("permissions.ai.toml");
|
||||
let existing = r#"# My permissions
|
||||
[permissions]
|
||||
allow = ["Read"]
|
||||
"#;
|
||||
tokio::fs::write(&file, existing).await.unwrap();
|
||||
|
||||
let rule = Rule {
|
||||
tool: "AtuinHistory".to_string(),
|
||||
scope: None,
|
||||
};
|
||||
write_rule(&file, &rule, RuleDisposition::Allow)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let content = tokio::fs::read_to_string(&file).await.unwrap();
|
||||
// Comment preserved
|
||||
assert!(content.contains("# My permissions"));
|
||||
// Both rules present
|
||||
assert!(content.contains(r#""Read""#));
|
||||
assert!(content.contains(r#""AtuinHistory""#));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn does_not_duplicate_existing_rule() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let file = dir.path().join("permissions.ai.toml");
|
||||
let existing = r#"[permissions]
|
||||
allow = ["AtuinHistory"]
|
||||
"#;
|
||||
tokio::fs::write(&file, existing).await.unwrap();
|
||||
|
||||
let rule = Rule {
|
||||
tool: "AtuinHistory".to_string(),
|
||||
scope: None,
|
||||
};
|
||||
write_rule(&file, &rule, RuleDisposition::Allow)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let content = tokio::fs::read_to_string(&file).await.unwrap();
|
||||
// Should appear exactly once
|
||||
assert_eq!(content.matches("AtuinHistory").count(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handles_inline_table_permissions() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let file = dir.path().join("permissions.ai.toml");
|
||||
// Inline table style — as_table_mut() would return None for this
|
||||
let existing = r#"permissions = { allow = ["Read"] }
|
||||
"#;
|
||||
tokio::fs::write(&file, existing).await.unwrap();
|
||||
|
||||
let rule = Rule {
|
||||
tool: "AtuinHistory".to_string(),
|
||||
scope: None,
|
||||
};
|
||||
write_rule(&file, &rule, RuleDisposition::Allow)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let content = tokio::fs::read_to_string(&file).await.unwrap();
|
||||
assert!(content.contains(r#""Read""#));
|
||||
assert!(content.contains(r#""AtuinHistory""#));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn writes_deny_rule() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let file = dir.path().join("permissions.ai.toml");
|
||||
let rule = Rule {
|
||||
tool: "Shell".to_string(),
|
||||
scope: None,
|
||||
};
|
||||
|
||||
write_rule(&file, &rule, RuleDisposition::Deny)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let content = tokio::fs::read_to_string(&file).await.unwrap();
|
||||
assert!(content.contains("deny"));
|
||||
assert!(content.contains(r#""Shell""#));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,372 @@
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
// SSE streaming
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
|
||||
use std::sync::mpsc;
|
||||
|
||||
use atuin_client::settings::AiCapabilities;
|
||||
use atuin_common::tls::ensure_crypto_provider;
|
||||
|
||||
use eventsource_stream::Eventsource;
|
||||
use eye_declare::Handle;
|
||||
use eyre::{Context, Result};
|
||||
use futures::StreamExt;
|
||||
use reqwest::Url;
|
||||
|
||||
use crate::{
|
||||
context::{AppContext, ClientContext},
|
||||
tools::ClientToolCall,
|
||||
tui::{Session, events::AiTuiEvent},
|
||||
};
|
||||
|
||||
/// Frames that alter the stream lifecycle — terminal or state-changing.
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum StreamControl {
|
||||
Done { session_id: String },
|
||||
Error(String),
|
||||
StatusChanged(String),
|
||||
}
|
||||
|
||||
/// Frames that carry conversation content — they mutate the event log.
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum StreamContent {
|
||||
TextChunk(String),
|
||||
ToolCall {
|
||||
id: String,
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: String,
|
||||
is_error: bool,
|
||||
},
|
||||
}
|
||||
|
||||
/// A frame from the SSE stream, classified as control or content.
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum StreamFrame {
|
||||
Content(StreamContent),
|
||||
Control(StreamControl),
|
||||
}
|
||||
|
||||
/// Per-turn request payload for the chat API.
|
||||
pub(crate) struct ChatRequest {
|
||||
pub messages: Vec<serde_json::Value>,
|
||||
pub session_id: Option<String>,
|
||||
pub capabilities: Vec<String>,
|
||||
}
|
||||
|
||||
impl ChatRequest {
|
||||
pub(crate) fn new(
|
||||
messages: Vec<serde_json::Value>,
|
||||
session_id: Option<String>,
|
||||
capabilities: &AiCapabilities,
|
||||
) -> Self {
|
||||
let mut caps = vec![];
|
||||
if capabilities.enable_history_search.unwrap_or(true) {
|
||||
caps.push("client_v1_atuin_history".to_string());
|
||||
}
|
||||
if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") {
|
||||
caps.extend(
|
||||
extra
|
||||
.split(',')
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty()),
|
||||
);
|
||||
}
|
||||
|
||||
Self {
|
||||
messages,
|
||||
session_id,
|
||||
capabilities: caps,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_chat_stream(
|
||||
hub_address: String,
|
||||
token: String,
|
||||
request: ChatRequest,
|
||||
client_ctx: ClientContext,
|
||||
send_cwd: bool,
|
||||
last_command: Option<String>,
|
||||
) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<StreamFrame>> + Send>> {
|
||||
Box::pin(async_stream::stream! {
|
||||
ensure_crypto_provider();
|
||||
let endpoint = match hub_url(&hub_address, "/api/cli/chat") {
|
||||
Ok(url) => url,
|
||||
Err(e) => {
|
||||
yield Err(e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
tracing::debug!("Sending SSE request to {endpoint}");
|
||||
|
||||
let context = client_ctx.to_json(send_cwd, last_command.as_deref());
|
||||
|
||||
let mut request_body = serde_json::json!({
|
||||
"messages": request.messages,
|
||||
"context": context,
|
||||
"capabilities": request.capabilities,
|
||||
});
|
||||
|
||||
if let Some(ref sid) = request.session_id {
|
||||
tracing::trace!("Including session_id in request: {sid}");
|
||||
request_body["session_id"] = serde_json::json!(sid);
|
||||
}
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let response = match client
|
||||
.post(endpoint.clone())
|
||||
.header("Accept", "text/event-stream")
|
||||
.bearer_auth(&token)
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
yield Err(eyre::eyre!("Failed to send SSE request: {}", e));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let status = response.status();
|
||||
if status == reqwest::StatusCode::UNAUTHORIZED {
|
||||
tracing::error!("SSE request failed with status: {status}, clearing session");
|
||||
let _ = atuin_client::hub::delete_session().await;
|
||||
yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again."));
|
||||
return;
|
||||
}
|
||||
if !status.is_success() {
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
tracing::error!("SSE request failed ({}): {}", status, body);
|
||||
yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body));
|
||||
return;
|
||||
}
|
||||
|
||||
let byte_stream = response.bytes_stream();
|
||||
let mut stream = byte_stream.eventsource();
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
match event {
|
||||
Ok(sse_event) => {
|
||||
let event_type = sse_event.event.as_str();
|
||||
let data = sse_event.data.clone();
|
||||
|
||||
tracing::debug!(event_type = %event_type, "SSE event received");
|
||||
|
||||
match event_type {
|
||||
"text" => {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data)
|
||||
&& let Some(content) = json.get("content").and_then(|v| v.as_str())
|
||||
{
|
||||
yield Ok(StreamFrame::Content(StreamContent::TextChunk(content.to_string())));
|
||||
}
|
||||
}
|
||||
"tool_call" => {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||
let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let input = json.get("input").cloned().unwrap_or(serde_json::json!({}));
|
||||
yield Ok(StreamFrame::Content(StreamContent::ToolCall { id, name, input }));
|
||||
}
|
||||
}
|
||||
"tool_result" => {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||
let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false);
|
||||
yield Ok(StreamFrame::Content(StreamContent::ToolResult { tool_use_id, content, is_error }));
|
||||
}
|
||||
}
|
||||
"status" => {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data)
|
||||
&& let Some(state) = json.get("state").and_then(|v| v.as_str())
|
||||
{
|
||||
yield Ok(StreamFrame::Control(StreamControl::StatusChanged(state.to_string())));
|
||||
}
|
||||
}
|
||||
"done" => {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||
let session_id = json.get("session_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
yield Ok(StreamFrame::Control(StreamControl::Done { session_id }));
|
||||
} else {
|
||||
yield Ok(StreamFrame::Control(StreamControl::Done { session_id: String::new() }));
|
||||
}
|
||||
break;
|
||||
}
|
||||
"error" => {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||
let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string();
|
||||
tracing::error!("SSE error: {}", message);
|
||||
yield Ok(StreamFrame::Control(StreamControl::Error(message)));
|
||||
} else {
|
||||
tracing::error!("SSE error: {}", data);
|
||||
yield Ok(StreamFrame::Control(StreamControl::Error(data)));
|
||||
}
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
yield Err(eyre::eyre!("SSE error: {}", e));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
// Async streaming task — pushes updates to app state via Handle
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
|
||||
pub(crate) async fn run_chat_stream(
|
||||
handle: Handle<Session>,
|
||||
tx: mpsc::Sender<AiTuiEvent>,
|
||||
app_ctx: AppContext,
|
||||
client_ctx: ClientContext,
|
||||
request: ChatRequest,
|
||||
) {
|
||||
let capabilities = request.capabilities.clone();
|
||||
let stream = create_chat_stream(
|
||||
app_ctx.endpoint.clone(),
|
||||
app_ctx.token.clone(),
|
||||
request,
|
||||
client_ctx,
|
||||
app_ctx.send_cwd,
|
||||
app_ctx.last_command.clone(),
|
||||
);
|
||||
futures::pin_mut!(stream);
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
match event {
|
||||
Ok(StreamFrame::Content(content)) => {
|
||||
apply_content_frame(&handle, &tx, &capabilities, content);
|
||||
}
|
||||
Ok(StreamFrame::Control(control)) => {
|
||||
let terminal = apply_control_frame(&handle, control);
|
||||
if terminal {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = e.to_string();
|
||||
handle.update(move |state| {
|
||||
state.streaming_error(msg);
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply a content frame to session state.
|
||||
/// Control flow: always continues the stream.
|
||||
fn apply_content_frame(
|
||||
handle: &Handle<Session>,
|
||||
tx: &mpsc::Sender<AiTuiEvent>,
|
||||
capabilities: &[String],
|
||||
content: StreamContent,
|
||||
) {
|
||||
match content {
|
||||
StreamContent::TextChunk(text) => {
|
||||
handle.update(move |state| {
|
||||
state.conversation.append_streaming_text(&text);
|
||||
});
|
||||
}
|
||||
StreamContent::ToolCall { id, name, input } => {
|
||||
if let Ok(tool) = ClientToolCall::try_from((name.as_str(), &input)) {
|
||||
// Enforce capability gating: reject tool calls the client didn't advertise.
|
||||
if let Some(required_cap) = tool.descriptor().capability
|
||||
&& !capabilities.iter().any(|c| c == required_cap)
|
||||
{
|
||||
tracing::warn!(
|
||||
tool = name,
|
||||
capability = required_cap,
|
||||
"Rejecting tool call: capability not advertised"
|
||||
);
|
||||
handle.update(move |state| {
|
||||
state.add_tool_call(id.clone(), name, input.clone());
|
||||
state.conversation.add_tool_result(
|
||||
id,
|
||||
format!("Tool not enabled: capability '{required_cap}' was not advertised by this client"),
|
||||
true,
|
||||
);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Client-side tool — add to tracker and conversation, queue permission check
|
||||
let id_for_event = id.clone();
|
||||
handle.update(move |state| {
|
||||
state.handle_client_tool_call(id_for_event, tool, input);
|
||||
});
|
||||
let _ = tx.send(AiTuiEvent::CheckToolCallPermission(id));
|
||||
} else {
|
||||
// Server-side tool — just add to conversation events
|
||||
handle.update(move |state| {
|
||||
state.add_tool_call(id, name, input);
|
||||
});
|
||||
}
|
||||
}
|
||||
StreamContent::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
} => {
|
||||
handle.update(move |state| {
|
||||
state
|
||||
.conversation
|
||||
.add_tool_result(tool_use_id, content, is_error);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply a control frame to session state.
|
||||
/// Returns true if the stream should terminate.
|
||||
fn apply_control_frame(handle: &Handle<Session>, control: StreamControl) -> bool {
|
||||
match control {
|
||||
StreamControl::StatusChanged(status) => {
|
||||
handle.update(move |state| {
|
||||
state.update_streaming_status(&status);
|
||||
});
|
||||
false
|
||||
}
|
||||
StreamControl::Done { session_id } => {
|
||||
handle.update(move |state| {
|
||||
if !session_id.is_empty() {
|
||||
state.conversation.store_session_id(session_id);
|
||||
}
|
||||
state.finalize_streaming();
|
||||
});
|
||||
true
|
||||
}
|
||||
StreamControl::Error(msg) => {
|
||||
handle.update(move |state| {
|
||||
state.streaming_error(msg);
|
||||
});
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn hub_url(base: &str, path: &str) -> Result<Url> {
|
||||
let base_with_slash = if base.ends_with('/') {
|
||||
base.to_string()
|
||||
} else {
|
||||
format!("{base}/")
|
||||
};
|
||||
let stripped = path.strip_prefix('/').unwrap_or(path);
|
||||
Url::parse(&base_with_slash)?
|
||||
.join(stripped)
|
||||
.context("failed to build hub URL")
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
/// Centralized metadata for a tool type.
|
||||
///
|
||||
/// Covers both client-side tools (ones the CLI executes locally) and
|
||||
/// server-side tools (ones the API executes remotely). This is the single
|
||||
/// source of truth for display text and classification.
|
||||
pub(crate) struct ToolDescriptor {
|
||||
/// Canonical wire names for this tool (the names the server sends).
|
||||
pub canonical_names: &'static [&'static str],
|
||||
/// The capability string the client must advertise for this tool to be
|
||||
/// accepted. `None` for server-side tools (always accepted).
|
||||
pub capability: Option<&'static str>,
|
||||
/// Imperative verb for permission prompts (e.g. "read", "run").
|
||||
pub display_verb: &'static str,
|
||||
/// Present-tense progressive verb for spinners (e.g. "Reading file...").
|
||||
pub progressive_verb: &'static str,
|
||||
/// Past-tense verb for summaries (e.g. "Read file").
|
||||
pub past_verb: &'static str,
|
||||
/// Whether this tool is executed client-side (by the CLI).
|
||||
pub is_client: bool,
|
||||
}
|
||||
|
||||
// ── Client-side tool descriptors ──
|
||||
|
||||
pub(crate) const READ: &ToolDescriptor = &ToolDescriptor {
|
||||
canonical_names: &["read_file"],
|
||||
capability: Some("client_v1_read"),
|
||||
display_verb: "read",
|
||||
progressive_verb: "Reading file...",
|
||||
past_verb: "Read file",
|
||||
is_client: true,
|
||||
};
|
||||
|
||||
pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor {
|
||||
canonical_names: &["str_replace", "file_create", "file_insert"],
|
||||
capability: Some("client_v1_write"),
|
||||
display_verb: "write to",
|
||||
progressive_verb: "Writing file...",
|
||||
past_verb: "Wrote file",
|
||||
is_client: true,
|
||||
};
|
||||
|
||||
pub(crate) const SHELL: &ToolDescriptor = &ToolDescriptor {
|
||||
canonical_names: &["execute_shell_command"],
|
||||
capability: Some("client_v1_shell"),
|
||||
display_verb: "run",
|
||||
progressive_verb: "Running command...",
|
||||
past_verb: "Ran command",
|
||||
is_client: true,
|
||||
};
|
||||
|
||||
pub(crate) const ATUIN_HISTORY: &ToolDescriptor = &ToolDescriptor {
|
||||
canonical_names: &["atuin_history"],
|
||||
capability: Some("client_v1_atuin_history"),
|
||||
display_verb: "search your Atuin history for",
|
||||
progressive_verb: "Searching...",
|
||||
past_verb: "Searched",
|
||||
is_client: true,
|
||||
};
|
||||
|
||||
// ── Server-side tool descriptors ──
|
||||
// These appear in tool summaries but aren't client-side tools.
|
||||
|
||||
pub(crate) const SERVER_SEARCH: &ToolDescriptor = &ToolDescriptor {
|
||||
canonical_names: &["web_search"],
|
||||
capability: None,
|
||||
display_verb: "search",
|
||||
progressive_verb: "Searching...",
|
||||
past_verb: "Searched",
|
||||
is_client: false,
|
||||
};
|
||||
|
||||
pub(crate) const SERVER_SCRAPE: &ToolDescriptor = &ToolDescriptor {
|
||||
canonical_names: &["web_scrape"],
|
||||
capability: None,
|
||||
display_verb: "scrape",
|
||||
progressive_verb: "Scraping...",
|
||||
past_verb: "Scraped",
|
||||
is_client: false,
|
||||
};
|
||||
|
||||
/// All known tool descriptors, for lookup by name.
|
||||
const ALL_DESCRIPTORS: &[&ToolDescriptor] = &[
|
||||
READ,
|
||||
WRITE,
|
||||
SHELL,
|
||||
ATUIN_HISTORY,
|
||||
SERVER_SEARCH,
|
||||
SERVER_SCRAPE,
|
||||
];
|
||||
|
||||
/// Look up a tool descriptor by its canonical wire name.
|
||||
/// Returns None for unknown tool names.
|
||||
pub(crate) fn by_name(name: &str) -> Option<&'static ToolDescriptor> {
|
||||
ALL_DESCRIPTORS
|
||||
.iter()
|
||||
.find(|d| d.canonical_names.contains(&name))
|
||||
.copied()
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -22,10 +22,11 @@ pub(crate) struct AtuinAi {
|
||||
pub has_command: bool,
|
||||
pub is_input_blank: bool,
|
||||
pub pending_confirmation: bool,
|
||||
pub has_executing_preview: bool,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct AtuinAiState {
|
||||
pub(crate) struct AtuinAiState {
|
||||
tx: Option<mpsc::Sender<AiTuiEvent>>,
|
||||
}
|
||||
|
||||
@@ -55,15 +56,24 @@ fn atuin_ai(
|
||||
return EventResult::Ignored;
|
||||
};
|
||||
|
||||
// Ctrl+C always exits
|
||||
// Ctrl+C — interrupt executing command or exit
|
||||
if modifiers.contains(KeyModifiers::CONTROL) && *code == KeyCode::Char('c') {
|
||||
let _ = tx.send(AiTuiEvent::Exit);
|
||||
if props.has_executing_preview {
|
||||
let _ = tx.send(AiTuiEvent::InterruptToolExecution);
|
||||
} else {
|
||||
let _ = tx.send(AiTuiEvent::Exit);
|
||||
}
|
||||
return EventResult::Consumed;
|
||||
}
|
||||
|
||||
match props.mode {
|
||||
AppMode::Input => match code {
|
||||
KeyCode::Esc => {
|
||||
if props.has_executing_preview {
|
||||
let _ = tx.send(AiTuiEvent::InterruptToolExecution);
|
||||
return EventResult::Consumed;
|
||||
}
|
||||
|
||||
if props.pending_confirmation {
|
||||
let _ = tx.send(AiTuiEvent::CancelConfirmation);
|
||||
return EventResult::Consumed;
|
||||
|
||||
@@ -16,20 +16,12 @@ use ratatui_widgets::paragraph::{Paragraph, Wrap};
|
||||
|
||||
/// A markdown rendering component backed by pulldown-cmark.
|
||||
#[props]
|
||||
pub struct Markdown {
|
||||
pub(crate) struct Markdown {
|
||||
pub source: String,
|
||||
}
|
||||
|
||||
impl Markdown {
|
||||
pub fn new(source: impl Into<String>) -> Self {
|
||||
Self {
|
||||
source: source.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Style configuration for markdown rendering.
|
||||
pub struct MarkdownStyles {
|
||||
pub(crate) struct MarkdownStyles {
|
||||
pub base: Style,
|
||||
pub code_inline: Style,
|
||||
pub code_block: Style,
|
||||
@@ -98,26 +90,22 @@ fn parse_markdown<'a>(source: &'a str, styles: &'a MarkdownStyles) -> Text<'stat
|
||||
|
||||
let mut style_stack: Vec<Style> = vec![styles.base];
|
||||
let mut in_code_block = false;
|
||||
let mut in_list_item = false;
|
||||
// True until the first paragraph inside a list item has been opened.
|
||||
// The first paragraph should flow inline with the "- " prefix.
|
||||
let mut list_item_first_para = false;
|
||||
|
||||
for event in parser {
|
||||
match event {
|
||||
Event::Start(Tag::Strong) => {
|
||||
let bold = style_stack
|
||||
.last()
|
||||
.copied()
|
||||
.unwrap_or(styles.base)
|
||||
.add_modifier(Modifier::BOLD);
|
||||
let bold = style_stack.last().copied().unwrap_or(styles.bold);
|
||||
style_stack.push(bold);
|
||||
}
|
||||
Event::End(TagEnd::Strong) => {
|
||||
style_stack.pop();
|
||||
}
|
||||
Event::Start(Tag::Emphasis) => {
|
||||
let italic = style_stack
|
||||
.last()
|
||||
.copied()
|
||||
.unwrap_or(styles.base)
|
||||
.add_modifier(Modifier::ITALIC);
|
||||
let italic = style_stack.last().copied().unwrap_or(styles.italic);
|
||||
style_stack.push(italic);
|
||||
}
|
||||
Event::End(TagEnd::Emphasis) => {
|
||||
@@ -170,12 +158,17 @@ fn parse_markdown<'a>(source: &'a str, styles: &'a MarkdownStyles) -> Text<'stat
|
||||
lines.push(Vec::new());
|
||||
}
|
||||
Event::Start(Tag::Paragraph) => {
|
||||
if current_line > 0 || !lines[0].is_empty() {
|
||||
// Two line advances: one to end the current line, one for a blank separator.
|
||||
current_line += 1;
|
||||
lines.push(Vec::new());
|
||||
if in_list_item && list_item_first_para {
|
||||
// First paragraph flows inline with the "- " prefix
|
||||
list_item_first_para = false;
|
||||
} else if current_line > 0 || !lines[0].is_empty() {
|
||||
current_line += 1;
|
||||
lines.push(Vec::new());
|
||||
if !in_list_item {
|
||||
// Blank separator between paragraphs (but not inside list items)
|
||||
current_line += 1;
|
||||
lines.push(Vec::new());
|
||||
}
|
||||
}
|
||||
}
|
||||
Event::End(TagEnd::Paragraph) => {}
|
||||
@@ -197,8 +190,12 @@ fn parse_markdown<'a>(source: &'a str, styles: &'a MarkdownStyles) -> Text<'stat
|
||||
lines.push(Vec::new());
|
||||
}
|
||||
lines[current_line].push(Span::styled("- ", Style::default().fg(Color::DarkGray)));
|
||||
in_list_item = true;
|
||||
list_item_first_para = true;
|
||||
}
|
||||
Event::End(TagEnd::Item) => {
|
||||
in_list_item = false;
|
||||
}
|
||||
Event::End(TagEnd::Item) => {}
|
||||
Event::Start(Tag::List(_)) => {
|
||||
if current_line > 0 || !lines[0].is_empty() {
|
||||
current_line += 1;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod atuin_ai;
|
||||
pub mod input_box;
|
||||
pub mod markdown;
|
||||
pub(crate) mod atuin_ai;
|
||||
pub(crate) mod input_box;
|
||||
pub(crate) mod markdown;
|
||||
pub(crate) mod select;
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
use std::sync::mpsc;
|
||||
|
||||
use crossterm::event::KeyCode;
|
||||
use eye_declare::{Elements, EventResult, Hooks, Span, Text, View, component, element, props};
|
||||
use ratatui::style::Style;
|
||||
use typed_builder::TypedBuilder;
|
||||
|
||||
use crate::tui::events::AiTuiEvent;
|
||||
|
||||
type OnSelectFn = Box<dyn Fn(&SelectOption) -> Option<AiTuiEvent> + Send + Sync + 'static>;
|
||||
|
||||
#[derive(TypedBuilder)]
|
||||
pub(crate) struct SelectOption {
|
||||
#[builder(setter(into))]
|
||||
pub label: String,
|
||||
#[builder(setter(into))]
|
||||
pub value: String,
|
||||
#[builder(default = Style::default())]
|
||||
pub label_style: Style,
|
||||
#[builder(default = Style::default().reversed())]
|
||||
pub selected_style: Style,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct PermissionSelectorState {
|
||||
selected_option: usize,
|
||||
tx: Option<mpsc::Sender<AiTuiEvent>>,
|
||||
}
|
||||
|
||||
#[props]
|
||||
pub(crate) struct Select {
|
||||
pub options: Vec<SelectOption>,
|
||||
pub on_select: OnSelectFn,
|
||||
}
|
||||
|
||||
#[component(props = Select, state = PermissionSelectorState)]
|
||||
pub(crate) fn permission_selector(
|
||||
props: &Select,
|
||||
state: &PermissionSelectorState,
|
||||
hooks: &mut Hooks<Select, PermissionSelectorState>,
|
||||
) -> Elements {
|
||||
hooks.use_focusable(true);
|
||||
hooks.use_autofocus();
|
||||
|
||||
hooks.use_context::<mpsc::Sender<AiTuiEvent>>(|tx, _, state| {
|
||||
state.tx = tx.cloned();
|
||||
});
|
||||
|
||||
hooks.use_event(move |event, props, state| {
|
||||
if !event.is_key_press() {
|
||||
return EventResult::Ignored;
|
||||
}
|
||||
|
||||
if let crossterm::event::Event::Key(key) = event {
|
||||
if key.kind != crossterm::event::KeyEventKind::Press {
|
||||
return EventResult::Ignored;
|
||||
}
|
||||
|
||||
match key.code {
|
||||
KeyCode::Up => {
|
||||
state.selected_option =
|
||||
(state.selected_option + props.options.len() - 1) % props.options.len();
|
||||
return EventResult::Consumed;
|
||||
}
|
||||
KeyCode::Down => {
|
||||
state.selected_option = (state.selected_option + 1) % props.options.len();
|
||||
return EventResult::Consumed;
|
||||
}
|
||||
KeyCode::Enter => {
|
||||
let option = &props.options[state.selected_option];
|
||||
if let Some(event) = (props.on_select)(option)
|
||||
&& let Some(ref tx) = state.tx
|
||||
{
|
||||
let _ = tx.send(event);
|
||||
}
|
||||
return EventResult::Consumed;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
EventResult::Ignored
|
||||
});
|
||||
|
||||
element!(
|
||||
View {
|
||||
#(for (index, option) in props.options.iter().enumerate() {
|
||||
Text { Span(text: &option.label, style: if index == state.selected_option {
|
||||
option.selected_style
|
||||
} else {
|
||||
option.label_style
|
||||
}) }
|
||||
})
|
||||
}
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,571 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::mpsc;
|
||||
|
||||
use crate::context::{AppContext, ClientContext};
|
||||
use crate::permissions::check::PermissionResponse;
|
||||
use crate::permissions::resolver::PermissionResolver;
|
||||
use crate::permissions::rule::Rule;
|
||||
use crate::permissions::writer::{self, RuleDisposition};
|
||||
use crate::stream::{ChatRequest, run_chat_stream};
|
||||
use crate::tools::{ClientToolCall, ToolPhase};
|
||||
use crate::tui::events::{AiTuiEvent, PermissionResult};
|
||||
use crate::tui::state::{ExitAction, Session};
|
||||
use eye_declare::Handle;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
pub(crate) fn dispatch(
|
||||
handle: &Handle<Session>,
|
||||
event: AiTuiEvent,
|
||||
tx: &mpsc::Sender<AiTuiEvent>,
|
||||
app_ctx: &AppContext,
|
||||
client_ctx: &ClientContext,
|
||||
) {
|
||||
match event {
|
||||
AiTuiEvent::ContinueAfterTools => {
|
||||
on_continue_after_tools(handle, tx, app_ctx, client_ctx);
|
||||
}
|
||||
AiTuiEvent::InputUpdated(input) => {
|
||||
on_input_updated(handle, input);
|
||||
}
|
||||
AiTuiEvent::SubmitInput(input) => {
|
||||
on_submit_input(handle, tx, app_ctx, client_ctx, input);
|
||||
}
|
||||
AiTuiEvent::SlashCommand(cmd) => {
|
||||
on_slash_command(handle, cmd);
|
||||
}
|
||||
AiTuiEvent::CheckToolCallPermission(id) => {
|
||||
on_check_tool_permission(handle, tx, app_ctx, id);
|
||||
}
|
||||
AiTuiEvent::SelectPermission(result) => {
|
||||
on_select_permission(handle, tx, app_ctx, result);
|
||||
}
|
||||
AiTuiEvent::CancelGeneration => {
|
||||
on_cancel_generation(handle);
|
||||
}
|
||||
AiTuiEvent::ExecuteCommand => {
|
||||
on_execute_command(handle);
|
||||
}
|
||||
AiTuiEvent::CancelConfirmation => {
|
||||
on_cancel_confirmation(handle);
|
||||
}
|
||||
AiTuiEvent::InterruptToolExecution => {
|
||||
on_interrupt_tool_execution(handle);
|
||||
}
|
||||
AiTuiEvent::InsertCommand => {
|
||||
on_insert_command(handle);
|
||||
}
|
||||
AiTuiEvent::Retry => {
|
||||
on_retry(handle, tx, app_ctx, client_ctx);
|
||||
}
|
||||
AiTuiEvent::Exit => {
|
||||
on_exit(handle);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn launch_stream(
|
||||
handle: &Handle<Session>,
|
||||
tx: &mpsc::Sender<AiTuiEvent>,
|
||||
app_ctx: &AppContext,
|
||||
client_ctx: &ClientContext,
|
||||
setup: impl FnOnce(&mut Session) + Send + 'static,
|
||||
) {
|
||||
let h2 = handle.clone();
|
||||
let tx2 = tx.clone();
|
||||
let app = app_ctx.clone();
|
||||
let cc = client_ctx.clone();
|
||||
let caps = app_ctx.capabilities.clone();
|
||||
handle.update(move |state| {
|
||||
(setup)(state);
|
||||
state.start_streaming();
|
||||
let messages = state.conversation.events_to_messages();
|
||||
let sid = state.conversation.session_id.clone();
|
||||
let request = ChatRequest::new(messages, sid, &caps);
|
||||
let task: JoinHandle<()> = tokio::spawn(async move {
|
||||
run_chat_stream(h2, tx2, app, cc, request).await;
|
||||
});
|
||||
state.stream_abort = Some(task.abort_handle());
|
||||
});
|
||||
}
|
||||
|
||||
fn on_continue_after_tools(
|
||||
handle: &Handle<Session>,
|
||||
tx: &mpsc::Sender<AiTuiEvent>,
|
||||
app_ctx: &AppContext,
|
||||
client_ctx: &ClientContext,
|
||||
) {
|
||||
launch_stream(handle, tx, app_ctx, client_ctx, |_state| {});
|
||||
}
|
||||
|
||||
fn on_input_updated(handle: &Handle<Session>, input: String) {
|
||||
let input_blank = input.trim().is_empty();
|
||||
|
||||
handle.update(move |state| {
|
||||
state.interaction.is_input_blank = input_blank;
|
||||
});
|
||||
}
|
||||
|
||||
fn on_submit_input(
|
||||
handle: &Handle<Session>,
|
||||
tx: &mpsc::Sender<AiTuiEvent>,
|
||||
app_ctx: &AppContext,
|
||||
client_ctx: &ClientContext,
|
||||
input: String,
|
||||
) {
|
||||
let input = input.trim().to_string();
|
||||
if input.is_empty() {
|
||||
let h2 = handle.clone();
|
||||
handle.update(move |state| {
|
||||
if state.conversation.has_any_command() {
|
||||
state.exit_action = Some(ExitAction::Execute(
|
||||
state.conversation.current_command().unwrap().to_string(),
|
||||
));
|
||||
} else {
|
||||
state.exit_action = Some(ExitAction::Cancel);
|
||||
}
|
||||
h2.exit();
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if input.starts_with('/') {
|
||||
handle.update(move |state| {
|
||||
state.conversation.handle_slash_command(&input);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Start generation and spawn streaming task
|
||||
launch_stream(handle, tx, app_ctx, client_ctx, |state| {
|
||||
state.start_generating(input);
|
||||
state.interaction.is_input_blank = true;
|
||||
});
|
||||
}
|
||||
|
||||
fn on_slash_command(handle: &Handle<Session>, command: String) {
|
||||
handle.update(move |state| {
|
||||
state.conversation.handle_slash_command(&command);
|
||||
});
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
// Tool execution dispatch
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Execute a tool call. Handles Shell tools (streaming with preview) and
|
||||
/// non-shell tools (synchronous) uniformly.
|
||||
fn execute_tool(
|
||||
handle: &Handle<Session>,
|
||||
tx: &mpsc::Sender<AiTuiEvent>,
|
||||
tool_id: String,
|
||||
tool: ClientToolCall,
|
||||
db: &std::sync::Arc<atuin_client::database::Sqlite>,
|
||||
) {
|
||||
match &tool {
|
||||
ClientToolCall::Shell(shell_call) => {
|
||||
let shell_call = shell_call.clone();
|
||||
execute_shell_tool(handle, tx, &tool_id, &shell_call);
|
||||
}
|
||||
_ => {
|
||||
execute_simple_tool(handle, tx, tool_id, tool, db);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a non-shell tool and finish the tool call.
|
||||
/// The ToolCall event is already in the conversation (added by handle_client_tool_call).
|
||||
fn execute_simple_tool(
|
||||
handle: &Handle<Session>,
|
||||
tx: &mpsc::Sender<AiTuiEvent>,
|
||||
tool_id: String,
|
||||
tool: ClientToolCall,
|
||||
db: &std::sync::Arc<atuin_client::database::Sqlite>,
|
||||
) {
|
||||
let h = handle.clone();
|
||||
let tx = tx.clone();
|
||||
let db = db.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let outcome = tool.execute(&db).await;
|
||||
h.update(move |state| {
|
||||
state.finish_tool_call(&tool_id, outcome);
|
||||
if !state.tool_tracker.has_pending() {
|
||||
let _ = tx.send(AiTuiEvent::ContinueAfterTools);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// Execute a shell tool with streaming VT100 preview.
|
||||
fn execute_shell_tool(
|
||||
handle: &Handle<Session>,
|
||||
tx: &mpsc::Sender<AiTuiEvent>,
|
||||
tool_id: &str,
|
||||
shell_call: &crate::tools::ShellToolCall,
|
||||
) {
|
||||
let h = handle.clone();
|
||||
let tx = tx.clone();
|
||||
let shell_call = shell_call.clone();
|
||||
let command = shell_call.command.clone();
|
||||
let tc_id = tool_id.to_string();
|
||||
|
||||
// 1. Set up channels for streaming output and interruption
|
||||
let (output_tx, mut output_rx) = tokio::sync::mpsc::channel::<Vec<String>>(32);
|
||||
let (abort_tx, abort_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
|
||||
// 2. Mark as executing with preview and store the abort sender on the tracker entry
|
||||
let tc_id_setup = tc_id.clone();
|
||||
h.update(move |state| {
|
||||
if let Some(tracked) = state.tool_tracker.get_mut(&tc_id_setup) {
|
||||
tracked.mark_executing_preview(command);
|
||||
tracked.abort_tx = Some(abort_tx);
|
||||
}
|
||||
});
|
||||
|
||||
// 3. Spawn a task to consume output updates and feed them to state
|
||||
let h_output = h.clone();
|
||||
let preview_id = tc_id.clone();
|
||||
let output_task = tokio::spawn(async move {
|
||||
while let Some(lines) = output_rx.recv().await {
|
||||
let id = preview_id.clone();
|
||||
h_output.update(move |state| {
|
||||
if let Some(tracked) = state.tool_tracker.get_mut(&id)
|
||||
&& let ToolPhase::ExecutingWithPreview {
|
||||
ref mut output_lines,
|
||||
..
|
||||
} = tracked.phase
|
||||
{
|
||||
*output_lines = lines;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// 4. Spawn the streaming execution task
|
||||
let tc_id_finish = tc_id;
|
||||
tokio::spawn(async move {
|
||||
let outcome =
|
||||
crate::tools::execute_shell_command_streaming(&shell_call, output_tx, abort_rx).await;
|
||||
|
||||
// Wait for the output task to finish so the final preview lines are captured
|
||||
let _ = output_task.await;
|
||||
|
||||
h.update(move |state| {
|
||||
state.finish_tool_call(&tc_id_finish, outcome);
|
||||
if !state.tool_tracker.has_pending() {
|
||||
let _ = tx.send(AiTuiEvent::ContinueAfterTools);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
// Permission handlers
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
|
||||
fn on_check_tool_permission(
|
||||
handle: &Handle<Session>,
|
||||
tx: &mpsc::Sender<AiTuiEvent>,
|
||||
app_ctx: &AppContext,
|
||||
id: String,
|
||||
) {
|
||||
let h2 = handle.clone();
|
||||
let tx_for_task = tx.clone();
|
||||
let db = app_ctx.history_db.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let id_for_error = id.clone();
|
||||
let result = check_tool_permission_inner(&h2, &tx_for_task, &db, id).await;
|
||||
|
||||
// If the inner function didn't handle the tool (returned an error message),
|
||||
// finish the tool call with that error so the conversation doesn't stall.
|
||||
if let Err(error_msg) = result {
|
||||
let tx = tx_for_task.clone();
|
||||
h2.update(move |state| {
|
||||
state.finish_tool_call(&id_for_error, crate::tools::ToolOutcome::Error(error_msg));
|
||||
if !state.tool_tracker.has_pending() {
|
||||
let _ = tx.send(AiTuiEvent::ContinueAfterTools);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Inner permission check that returns Err(message) if the tool call should be
|
||||
/// finished with an error. Returns Ok(()) if the tool was handled (executed,
|
||||
/// denied, or sent to the permission UI).
|
||||
async fn check_tool_permission_inner(
|
||||
h2: &Handle<Session>,
|
||||
tx: &mpsc::Sender<AiTuiEvent>,
|
||||
db: &std::sync::Arc<atuin_client::database::Sqlite>,
|
||||
id: String,
|
||||
) -> Result<(), String> {
|
||||
// 1. Fetch the tracked tool's data
|
||||
let id_for_fetch = id.clone();
|
||||
let (tool, target_dir) = h2
|
||||
.fetch(move |state| {
|
||||
state
|
||||
.tool_tracker
|
||||
.get(&id_for_fetch)
|
||||
.map(|t| (t.tool.clone(), t.target_dir().map(PathBuf::from)))
|
||||
})
|
||||
.await
|
||||
.map_err(|e| format!("Internal error fetching tool state: {e}"))?
|
||||
.ok_or_else(|| "Internal error: tool not found in tracker".to_string())?;
|
||||
|
||||
// 2. Resolve working directory
|
||||
let working_dir = target_dir
|
||||
.or_else(|| std::env::current_dir().ok())
|
||||
.ok_or_else(|| "Could not determine working directory".to_string())?;
|
||||
|
||||
// 3. Create permission resolver and check
|
||||
let resolver = PermissionResolver::new(working_dir)
|
||||
.await
|
||||
.map_err(|e| format!("Permission check failed: {e}"))?;
|
||||
|
||||
let response = resolver
|
||||
.check(&tool)
|
||||
.await
|
||||
.map_err(|e| format!("Permission check failed: {e}"))?;
|
||||
|
||||
// 4. Handle response — all paths here handle the tool, so return Ok
|
||||
let id_clone = id.clone();
|
||||
match response {
|
||||
PermissionResponse::Allowed => {
|
||||
execute_tool(h2, tx, id, tool, db);
|
||||
}
|
||||
PermissionResponse::Denied => {
|
||||
let tx = tx.clone();
|
||||
h2.update(move |state| {
|
||||
state.finish_tool_call(
|
||||
&id_clone,
|
||||
crate::tools::ToolOutcome::Error(
|
||||
"Permission denied on the user's system".to_string(),
|
||||
),
|
||||
);
|
||||
if !state.tool_tracker.has_pending() {
|
||||
let _ = tx.send(AiTuiEvent::ContinueAfterTools);
|
||||
}
|
||||
});
|
||||
}
|
||||
PermissionResponse::Ask => {
|
||||
h2.update(move |state| {
|
||||
if let Some(tracked) = state.tool_tracker.get_mut(&id_clone) {
|
||||
tracked.mark_asking();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn on_select_permission(
|
||||
handle: &Handle<Session>,
|
||||
tx: &mpsc::Sender<AiTuiEvent>,
|
||||
app_ctx: &AppContext,
|
||||
permission: PermissionResult,
|
||||
) {
|
||||
let tx = tx.clone();
|
||||
let h2 = handle.clone();
|
||||
|
||||
match permission {
|
||||
PermissionResult::Allow => {
|
||||
// Fetch the tool that's asking for permission, then execute it
|
||||
let db = app_ctx.history_db.clone();
|
||||
tokio::spawn(async move {
|
||||
let Ok(Some((tool_id, tool))) = h2
|
||||
.fetch(move |state| {
|
||||
state
|
||||
.tool_tracker
|
||||
.asking_for_permission()
|
||||
.map(|t| (t.id.clone(), t.tool.clone()))
|
||||
})
|
||||
.await
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
execute_tool(&h2, &tx, tool_id, tool, &db);
|
||||
});
|
||||
}
|
||||
PermissionResult::AlwaysAllowInDir => {
|
||||
let db = app_ctx.history_db.clone();
|
||||
let git_root = app_ctx.git_root.clone();
|
||||
tokio::spawn(async move {
|
||||
let Ok(Some((tool_id, tool))) = h2
|
||||
.fetch(move |state| {
|
||||
state
|
||||
.tool_tracker
|
||||
.asking_for_permission()
|
||||
.map(|t| (t.id.clone(), t.tool.clone()))
|
||||
})
|
||||
.await
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
// Write the rule to the project (git root) or cwd permissions file
|
||||
let project_root = git_root
|
||||
.or_else(|| std::env::current_dir().ok())
|
||||
.unwrap_or_else(|| PathBuf::from("."));
|
||||
let file_path = writer::project_permissions_path(&project_root);
|
||||
let rule = Rule {
|
||||
tool: tool.rule_name().to_string(),
|
||||
scope: None,
|
||||
};
|
||||
if let Err(e) = writer::write_rule(&file_path, &rule, RuleDisposition::Allow).await
|
||||
{
|
||||
tracing::error!("Failed to write project permission rule: {e}");
|
||||
}
|
||||
|
||||
execute_tool(&h2, &tx, tool_id, tool, &db);
|
||||
});
|
||||
}
|
||||
PermissionResult::AlwaysAllow => {
|
||||
let db = app_ctx.history_db.clone();
|
||||
tokio::spawn(async move {
|
||||
let Ok(Some((tool_id, tool))) = h2
|
||||
.fetch(move |state| {
|
||||
state
|
||||
.tool_tracker
|
||||
.asking_for_permission()
|
||||
.map(|t| (t.id.clone(), t.tool.clone()))
|
||||
})
|
||||
.await
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
// Write the rule to the global permissions file
|
||||
let file_path = writer::global_permissions_path();
|
||||
let rule = Rule {
|
||||
tool: tool.rule_name().to_string(),
|
||||
scope: None,
|
||||
};
|
||||
if let Err(e) = writer::write_rule(&file_path, &rule, RuleDisposition::Allow).await
|
||||
{
|
||||
tracing::error!("Failed to write global permission rule: {e}");
|
||||
}
|
||||
|
||||
execute_tool(&h2, &tx, tool_id, tool, &db);
|
||||
});
|
||||
}
|
||||
PermissionResult::Deny => {
|
||||
h2.update(move |state| {
|
||||
let Some(tracked) = state.tool_tracker.asking_for_permission() else {
|
||||
return;
|
||||
};
|
||||
let tool_id = tracked.id.clone();
|
||||
|
||||
state.finish_tool_call(
|
||||
&tool_id,
|
||||
crate::tools::ToolOutcome::Error("Permission denied by the user".to_string()),
|
||||
);
|
||||
if !state.tool_tracker.has_pending() {
|
||||
let _ = tx.send(AiTuiEvent::ContinueAfterTools);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
// Other handlers
|
||||
// ───────────────────────────────────────────────────────────────────
|
||||
|
||||
fn on_cancel_generation(handle: &Handle<Session>) {
|
||||
handle.update(|state| match state.interaction.mode {
|
||||
crate::tui::state::AppMode::Generating => {
|
||||
state.cancel_generation();
|
||||
}
|
||||
crate::tui::state::AppMode::Streaming => {
|
||||
state.cancel_streaming();
|
||||
}
|
||||
_ => {}
|
||||
});
|
||||
}
|
||||
|
||||
fn on_execute_command(handle: &Handle<Session>) {
|
||||
let h2 = handle.clone();
|
||||
handle.update(move |state| {
|
||||
let cmd = state.conversation.current_command().map(|c| c.to_string());
|
||||
if let Some(cmd) = cmd {
|
||||
if state.conversation.is_current_command_dangerous()
|
||||
&& !state.interaction.confirmation_pending
|
||||
{
|
||||
state.interaction.confirmation_pending = true;
|
||||
} else {
|
||||
state.interaction.confirmation_pending = false;
|
||||
state.exit_action = Some(ExitAction::Execute(cmd));
|
||||
h2.exit();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn on_cancel_confirmation(handle: &Handle<Session>) {
|
||||
handle.update(move |state| {
|
||||
state.interaction.confirmation_pending = false;
|
||||
});
|
||||
}
|
||||
|
||||
fn on_insert_command(handle: &Handle<Session>) {
|
||||
let h2 = handle.clone();
|
||||
handle.update(move |state| {
|
||||
let cmd = state.conversation.current_command().map(|c| c.to_string());
|
||||
if let Some(cmd) = cmd {
|
||||
state.interaction.confirmation_pending = false;
|
||||
state.exit_action = Some(ExitAction::Insert(cmd));
|
||||
h2.exit();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn on_retry(
|
||||
handle: &Handle<Session>,
|
||||
tx: &mpsc::Sender<AiTuiEvent>,
|
||||
app_ctx: &AppContext,
|
||||
client_ctx: &ClientContext,
|
||||
) {
|
||||
launch_stream(handle, tx, app_ctx, client_ctx, |state| {
|
||||
state.retry();
|
||||
});
|
||||
}
|
||||
|
||||
fn on_exit(handle: &Handle<Session>) {
|
||||
let h2 = handle.clone();
|
||||
handle.update(move |state| {
|
||||
if let Some(abort) = state.stream_abort.take() {
|
||||
abort.abort();
|
||||
}
|
||||
state.exit_action = Some(ExitAction::Cancel);
|
||||
h2.exit();
|
||||
});
|
||||
}
|
||||
|
||||
fn on_interrupt_tool_execution(handle: &Handle<Session>) {
|
||||
handle.update(move |state| {
|
||||
// Find executing previews, send interrupt, and mark as interrupted
|
||||
for tracked in state.tool_tracker.iter_mut() {
|
||||
if let ToolPhase::ExecutingWithPreview {
|
||||
ref mut interrupted,
|
||||
ref mut exit_code,
|
||||
..
|
||||
} = tracked.phase
|
||||
{
|
||||
*interrupted = true;
|
||||
if exit_code.is_none() {
|
||||
*exit_code = Some(-1);
|
||||
}
|
||||
// Send interrupt signal via the tracker entry's abort channel
|
||||
if let Some(abort_tx) = tracked.abort_tx.take() {
|
||||
let _ = abort_tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The spawned execution task will handle finalizing and sending
|
||||
// ContinueAfterTools when the process exits. Input mode is already active.
|
||||
});
|
||||
}
|
||||
@@ -5,13 +5,20 @@
|
||||
/// eye-declare's context system. The main event loop in `inline.rs`
|
||||
/// receives them and mutates `AppState` accordingly.
|
||||
#[derive(Debug)]
|
||||
pub enum AiTuiEvent {
|
||||
pub(crate) enum AiTuiEvent {
|
||||
/// User updated the input text
|
||||
InputUpdated(String),
|
||||
/// User submitted text input (Enter in Input mode)
|
||||
SubmitInput(String),
|
||||
/// User entered a slash command (e.g. "/help")
|
||||
#[allow(unused)]
|
||||
SlashCommand(String),
|
||||
/// Check the permission for a tool call
|
||||
CheckToolCallPermission(String),
|
||||
/// User selected a permission
|
||||
SelectPermission(PermissionResult),
|
||||
/// Continue after client tools have completed
|
||||
ContinueAfterTools,
|
||||
/// Cancel active generation or streaming (Esc during Generating/Streaming)
|
||||
CancelGeneration,
|
||||
/// Execute the suggested command
|
||||
@@ -20,8 +27,18 @@ pub enum AiTuiEvent {
|
||||
InsertCommand,
|
||||
/// Cancel confirmation of dangerous command
|
||||
CancelConfirmation,
|
||||
/// Interrupt a running tool execution (Ctrl+C during ExecutingPreview)
|
||||
InterruptToolExecution,
|
||||
/// Retry after error
|
||||
Retry,
|
||||
/// Exit the application
|
||||
Exit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) enum PermissionResult {
|
||||
Allow,
|
||||
AlwaysAllowInDir,
|
||||
AlwaysAllow,
|
||||
Deny,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
pub mod components;
|
||||
pub mod events;
|
||||
pub mod state;
|
||||
pub mod view;
|
||||
pub(crate) mod components;
|
||||
pub(crate) mod dispatch;
|
||||
pub(crate) mod events;
|
||||
pub(crate) mod state;
|
||||
pub(crate) mod view;
|
||||
|
||||
pub use state::{AppMode, AppState, ConversationEvent, ExitAction};
|
||||
pub(crate) use state::{ConversationEvent, Session};
|
||||
|
||||
+326
-284
@@ -5,9 +5,11 @@
|
||||
|
||||
use tokio::task::AbortHandle;
|
||||
|
||||
use crate::tools::{ClientToolCall, ToolOutcome, ToolTracker};
|
||||
|
||||
/// Streaming status indicators from server
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum StreamingStatus {
|
||||
pub(crate) enum StreamingStatus {
|
||||
Processing,
|
||||
Searching,
|
||||
Thinking,
|
||||
@@ -15,7 +17,7 @@ pub enum StreamingStatus {
|
||||
}
|
||||
|
||||
impl StreamingStatus {
|
||||
pub fn from_status_str(s: &str) -> Self {
|
||||
pub(crate) fn from_status_str(s: &str) -> Self {
|
||||
match s {
|
||||
"processing" => Self::Processing,
|
||||
"searching" => Self::Searching,
|
||||
@@ -23,20 +25,11 @@ impl StreamingStatus {
|
||||
_ => Self::Thinking,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_text(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Processing => "Processing...",
|
||||
Self::Searching => "Searching...",
|
||||
Self::Thinking => "Thinking...",
|
||||
Self::WaitingForTools => "Waiting for tools...",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Conversation event types matching the API protocol
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ConversationEvent {
|
||||
pub(crate) enum ConversationEvent {
|
||||
/// User message (what the user typed)
|
||||
UserMessage { content: String },
|
||||
/// Text content from assistant (streamed or complete)
|
||||
@@ -62,48 +55,8 @@ pub enum ConversationEvent {
|
||||
}
|
||||
|
||||
impl ConversationEvent {
|
||||
/// Convert to JSON for API calls
|
||||
pub fn to_json(&self) -> serde_json::Value {
|
||||
match self {
|
||||
ConversationEvent::UserMessage { content } => serde_json::json!({
|
||||
"type": "user_message",
|
||||
"content": content
|
||||
}),
|
||||
ConversationEvent::Text { content } => serde_json::json!({
|
||||
"type": "text",
|
||||
"content": content
|
||||
}),
|
||||
ConversationEvent::ToolCall { id, name, input } => serde_json::json!({
|
||||
"type": "tool_call",
|
||||
"id": id,
|
||||
"name": name,
|
||||
"input": input
|
||||
}),
|
||||
ConversationEvent::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
} => serde_json::json!({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use_id,
|
||||
"content": content,
|
||||
"is_error": is_error
|
||||
}),
|
||||
ConversationEvent::OutOfBandOutput {
|
||||
name,
|
||||
command,
|
||||
content,
|
||||
} => serde_json::json!({
|
||||
"type": "out_of_band_output",
|
||||
"name": name,
|
||||
"command": command,
|
||||
"content": content
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract command from a suggest_command tool call
|
||||
pub fn as_command(&self) -> Option<&str> {
|
||||
pub(crate) fn as_command(&self) -> Option<&str> {
|
||||
if let ConversationEvent::ToolCall { name, input, .. } = self
|
||||
&& name == "suggest_command"
|
||||
{
|
||||
@@ -113,8 +66,9 @@ impl ConversationEvent {
|
||||
}
|
||||
}
|
||||
|
||||
/// Application mode for key handling and footer text.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
||||
pub enum AppMode {
|
||||
pub(crate) enum AppMode {
|
||||
/// User is typing input
|
||||
Input,
|
||||
/// Waiting for generation (showing spinner)
|
||||
@@ -126,7 +80,7 @@ pub enum AppMode {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ExitAction {
|
||||
pub(crate) enum ExitAction {
|
||||
/// Run the command
|
||||
Execute(String),
|
||||
/// Insert command without running
|
||||
@@ -135,47 +89,20 @@ pub enum ExitAction {
|
||||
Cancel,
|
||||
}
|
||||
|
||||
/// Application state — the domain model
|
||||
///
|
||||
/// Conversation is stored as a sequence of events matching the API protocol.
|
||||
/// The view function derives the UI from this state.
|
||||
/// Owned event log and session ID
|
||||
#[derive(Debug)]
|
||||
pub struct AppState {
|
||||
/// Current application mode
|
||||
pub mode: AppMode,
|
||||
pub(crate) struct Conversation {
|
||||
/// Conversation events (source of truth, matches API protocol)
|
||||
pub events: Vec<ConversationEvent>,
|
||||
/// Current error message
|
||||
pub error: Option<String>,
|
||||
/// Exit action (set when exiting)
|
||||
pub exit_action: Option<ExitAction>,
|
||||
/// Session ID from server
|
||||
pub session_id: Option<String>,
|
||||
/// Current streaming status
|
||||
pub streaming_status: Option<StreamingStatus>,
|
||||
/// Whether the input is blank
|
||||
pub is_input_blank: bool,
|
||||
/// Whether current turn was interrupted by user
|
||||
pub was_interrupted: bool,
|
||||
/// True when user has pressed Enter once on a dangerous command
|
||||
pub confirmation_pending: bool,
|
||||
/// Abort handle for the active streaming task, if any
|
||||
pub stream_abort: Option<AbortHandle>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
impl Conversation {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
mode: AppMode::Input,
|
||||
events: Vec::new(),
|
||||
error: None,
|
||||
exit_action: None,
|
||||
session_id: None,
|
||||
streaming_status: None,
|
||||
is_input_blank: false,
|
||||
was_interrupted: false,
|
||||
confirmation_pending: false,
|
||||
stream_abort: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,16 +122,57 @@ impl AppState {
|
||||
i += 1;
|
||||
}
|
||||
ConversationEvent::Text { content } => {
|
||||
messages.push(serde_json::json!({
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
}));
|
||||
i += 1;
|
||||
// Check if the next event(s) are ToolCalls — if so, combine
|
||||
// into a single assistant message with mixed content blocks.
|
||||
let next_is_tool_call = events
|
||||
.get(i + 1)
|
||||
.is_some_and(|e| matches!(e, ConversationEvent::ToolCall { .. }));
|
||||
|
||||
if next_is_tool_call {
|
||||
let mut content_blocks = Vec::new();
|
||||
|
||||
if !content.is_empty() {
|
||||
content_blocks.push(serde_json::json!({
|
||||
"type": "text",
|
||||
"text": content
|
||||
}));
|
||||
}
|
||||
|
||||
while let Some(ConversationEvent::ToolCall {
|
||||
id, name, input, ..
|
||||
}) = events.get(i + 1)
|
||||
{
|
||||
content_blocks.push(serde_json::json!({
|
||||
"type": "tool_use",
|
||||
"id": id,
|
||||
"name": name,
|
||||
"input": input
|
||||
}));
|
||||
i += 1;
|
||||
}
|
||||
|
||||
messages.push(serde_json::json!({
|
||||
"role": "assistant",
|
||||
"content": content_blocks
|
||||
}));
|
||||
i += 1;
|
||||
} else {
|
||||
messages.push(serde_json::json!({
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
}));
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
ConversationEvent::ToolCall { .. } => {
|
||||
// ToolCalls without preceding Text (shouldn't normally happen,
|
||||
// but handle defensively)
|
||||
let mut tool_uses = Vec::new();
|
||||
while i < events.len() {
|
||||
if let ConversationEvent::ToolCall { id, name, input } = &events[i] {
|
||||
if let ConversationEvent::ToolCall {
|
||||
id, name, input, ..
|
||||
} = &events[i]
|
||||
{
|
||||
tool_uses.push(serde_json::json!({
|
||||
"type": "tool_use",
|
||||
"id": id,
|
||||
@@ -247,53 +215,42 @@ impl AppState {
|
||||
messages
|
||||
}
|
||||
|
||||
// ===== Generation lifecycle methods =====
|
||||
/// Get the most recent command from events
|
||||
pub fn current_command(&self) -> Option<&str> {
|
||||
self.events.iter().rev().find_map(|e| e.as_command())
|
||||
}
|
||||
|
||||
/// Start generating from submitted input
|
||||
pub fn start_generating(&mut self, input: String) {
|
||||
/// Check if any turn in the conversation has a command
|
||||
pub fn has_any_command(&self) -> bool {
|
||||
self.events.iter().any(|e| {
|
||||
if let ConversationEvent::ToolCall { name, input, .. } = e {
|
||||
name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if the most recent command is marked dangerous
|
||||
pub fn is_current_command_dangerous(&self) -> bool {
|
||||
self.events
|
||||
.push(ConversationEvent::UserMessage { content: input });
|
||||
self.mode = AppMode::Generating;
|
||||
}
|
||||
|
||||
/// Generation error occurred
|
||||
pub fn generation_error(&mut self, error: String) {
|
||||
self.error = Some(error);
|
||||
self.mode = AppMode::Error;
|
||||
}
|
||||
|
||||
/// Cancel during generation
|
||||
pub fn cancel_generation(&mut self) {
|
||||
if let Some(abort) = self.stream_abort.take() {
|
||||
abort.abort();
|
||||
}
|
||||
if let Some(ConversationEvent::UserMessage { .. }) = self.events.last() {
|
||||
self.events.pop();
|
||||
}
|
||||
self.mode = AppMode::Input;
|
||||
}
|
||||
|
||||
// ===== Streaming lifecycle methods =====
|
||||
|
||||
/// Start streaming response.
|
||||
/// Pushes an empty Text event that will be mutated in-place as chunks arrive.
|
||||
pub fn start_streaming(&mut self) {
|
||||
self.events.push(ConversationEvent::Text {
|
||||
content: String::new(),
|
||||
});
|
||||
self.streaming_status = None;
|
||||
self.was_interrupted = false;
|
||||
self.mode = AppMode::Streaming;
|
||||
}
|
||||
|
||||
/// Store session ID from server response
|
||||
pub fn store_session_id(&mut self, session_id: String) {
|
||||
self.session_id = Some(session_id);
|
||||
}
|
||||
|
||||
/// Update streaming status from SSE event
|
||||
pub fn update_streaming_status(&mut self, status: &str) {
|
||||
self.streaming_status = Some(StreamingStatus::from_status_str(status));
|
||||
.iter()
|
||||
.rev()
|
||||
.find_map(|e| {
|
||||
if let ConversationEvent::ToolCall { name, input, .. } = e
|
||||
&& name == "suggest_command"
|
||||
{
|
||||
let danger_level = input
|
||||
.get("danger")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("low");
|
||||
return Some(
|
||||
danger_level == "high" || danger_level == "medium" || danger_level == "med",
|
||||
);
|
||||
}
|
||||
None
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Get a mutable reference to the last Text event's content (the streaming buffer).
|
||||
@@ -307,28 +264,15 @@ impl AppState {
|
||||
})
|
||||
}
|
||||
|
||||
/// Cancel streaming with context preservation
|
||||
pub fn cancel_streaming(&mut self) {
|
||||
if let Some(abort) = self.stream_abort.take() {
|
||||
abort.abort();
|
||||
}
|
||||
self.was_interrupted = true;
|
||||
|
||||
if let Some(content) = self.streaming_content_mut() {
|
||||
let trimmed = content.trim_start().to_string();
|
||||
if trimmed.is_empty() {
|
||||
// Remove the empty text event
|
||||
*content = String::new();
|
||||
/// Remove trailing empty Text events from the events list
|
||||
fn remove_empty_trailing_text(&mut self) {
|
||||
while let Some(ConversationEvent::Text { content }) = self.events.last() {
|
||||
if content.is_empty() {
|
||||
self.events.pop();
|
||||
} else {
|
||||
*content = format!("{trimmed}\n\n[User cancelled this generation]");
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Remove trailing empty Text events
|
||||
self.remove_empty_trailing_text();
|
||||
|
||||
self.streaming_status = None;
|
||||
self.confirmation_pending = false;
|
||||
self.mode = AppMode::Input;
|
||||
}
|
||||
|
||||
/// Append text chunk during streaming (mutates the last Text event in-place)
|
||||
@@ -354,26 +298,6 @@ impl AppState {
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a tool call event during streaming.
|
||||
/// The current streaming text is already in events, so we just push the tool call.
|
||||
pub fn add_tool_call(&mut self, id: String, name: String, input: serde_json::Value) {
|
||||
// Trim the streaming text event
|
||||
if let Some(content) = self.streaming_content_mut() {
|
||||
let trimmed = content.trim_start().to_string();
|
||||
*content = trimmed;
|
||||
}
|
||||
self.remove_empty_trailing_text();
|
||||
|
||||
let is_suggest_command = name == "suggest_command";
|
||||
self.events
|
||||
.push(ConversationEvent::ToolCall { id, name, input });
|
||||
|
||||
if is_suggest_command {
|
||||
self.streaming_status = None;
|
||||
self.mode = AppMode::Input;
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a tool result event during streaming
|
||||
pub fn add_tool_result(&mut self, tool_use_id: String, content: String, is_error: bool) {
|
||||
self.events.push(ConversationEvent::ToolResult {
|
||||
@@ -383,47 +307,9 @@ impl AppState {
|
||||
});
|
||||
}
|
||||
|
||||
/// Finalize streaming — trim the accumulated text and change mode
|
||||
pub fn finalize_streaming(&mut self) {
|
||||
if let Some(content) = self.streaming_content_mut() {
|
||||
let trimmed = content.trim_start().to_string();
|
||||
*content = trimmed;
|
||||
}
|
||||
self.remove_empty_trailing_text();
|
||||
self.streaming_status = None;
|
||||
self.mode = AppMode::Input;
|
||||
}
|
||||
|
||||
/// Streaming error — remove the partial text event
|
||||
pub fn streaming_error(&mut self, error: String) {
|
||||
self.remove_empty_trailing_text();
|
||||
self.error = Some(error);
|
||||
self.mode = AppMode::Error;
|
||||
}
|
||||
|
||||
/// Remove trailing empty Text events from the events list
|
||||
fn remove_empty_trailing_text(&mut self) {
|
||||
while let Some(ConversationEvent::Text { content }) = self.events.last() {
|
||||
if content.is_empty() {
|
||||
self.events.pop();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Edit mode and exit methods =====
|
||||
|
||||
/// Start edit mode for refinement
|
||||
pub fn start_edit_mode(&mut self) {
|
||||
self.confirmation_pending = false;
|
||||
self.mode = AppMode::Input;
|
||||
}
|
||||
|
||||
/// Retry after error
|
||||
pub fn retry(&mut self) {
|
||||
self.error = None;
|
||||
self.mode = AppMode::Generating;
|
||||
/// Store session ID from server response
|
||||
pub fn store_session_id(&mut self, session_id: String) {
|
||||
self.session_id = Some(session_id);
|
||||
}
|
||||
|
||||
/// Handle a slash command
|
||||
@@ -445,85 +331,247 @@ impl AppState {
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Query methods =====
|
||||
/// Ephemeral UI/presentation state
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Interaction {
|
||||
/// Current application mode
|
||||
pub mode: AppMode,
|
||||
/// Whether the input is blank
|
||||
pub is_input_blank: bool,
|
||||
/// True when user has pressed Enter once on a dangerous command
|
||||
pub confirmation_pending: bool,
|
||||
/// Current streaming status
|
||||
pub streaming_status: Option<StreamingStatus>,
|
||||
/// Whether current turn was interrupted by user
|
||||
pub was_interrupted: bool,
|
||||
/// Current error message
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Get the most recent command from events
|
||||
pub fn current_command(&self) -> Option<&str> {
|
||||
self.events.iter().rev().find_map(|e| e.as_command())
|
||||
impl Interaction {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
mode: AppMode::Input,
|
||||
is_input_blank: false,
|
||||
confirmation_pending: false,
|
||||
streaming_status: None,
|
||||
was_interrupted: false,
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Top-level session state
|
||||
///
|
||||
/// Decomposed into `Conversation` (event log + session ID) and
|
||||
/// `Interaction` (ephemeral UI state). Session methods that cross
|
||||
/// both sub-structs live here.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Session {
|
||||
pub conversation: Conversation,
|
||||
pub interaction: Interaction,
|
||||
/// Tracks all tool calls through their full lifecycle.
|
||||
pub tool_tracker: ToolTracker,
|
||||
/// Whether the session is running inside a git project (for permission UI labels).
|
||||
pub in_git_project: bool,
|
||||
/// Exit action (set when exiting)
|
||||
pub exit_action: Option<ExitAction>,
|
||||
/// Abort handle for the active streaming task, if any
|
||||
pub stream_abort: Option<AbortHandle>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(in_git_project: bool) -> Self {
|
||||
Self {
|
||||
conversation: Conversation::new(),
|
||||
interaction: Interaction::new(),
|
||||
tool_tracker: ToolTracker::new(),
|
||||
in_git_project,
|
||||
exit_action: None,
|
||||
stream_abort: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the most recent command is marked dangerous
|
||||
pub fn is_current_command_dangerous(&self) -> bool {
|
||||
self.events
|
||||
.iter()
|
||||
.rev()
|
||||
.find_map(|e| {
|
||||
if let ConversationEvent::ToolCall { name, input, .. } = e
|
||||
&& name == "suggest_command"
|
||||
{
|
||||
let danger_level = input
|
||||
.get("danger")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("low");
|
||||
return Some(
|
||||
danger_level == "high" || danger_level == "medium" || danger_level == "med",
|
||||
);
|
||||
}
|
||||
None
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
// ===== Generation lifecycle methods =====
|
||||
|
||||
/// Count non-suggest_command tool calls since the last user message
|
||||
pub fn tool_count_since_last_user(&self) -> usize {
|
||||
let last_user_idx = self
|
||||
/// Start generating from submitted input
|
||||
pub fn start_generating(&mut self, input: String) {
|
||||
self.conversation
|
||||
.events
|
||||
.iter()
|
||||
.rposition(|e| matches!(e, ConversationEvent::UserMessage { .. }))
|
||||
.unwrap_or(0);
|
||||
.push(ConversationEvent::UserMessage { content: input });
|
||||
self.interaction.mode = AppMode::Generating;
|
||||
}
|
||||
|
||||
let mut completed = 0;
|
||||
let mut in_flight = false;
|
||||
/// Generation error occurred
|
||||
#[expect(dead_code)]
|
||||
pub fn generation_error(&mut self, error: String) {
|
||||
self.interaction.error = Some(error);
|
||||
self.interaction.mode = AppMode::Error;
|
||||
}
|
||||
|
||||
for event in &self.events[last_user_idx..] {
|
||||
match event {
|
||||
ConversationEvent::ToolCall { name, .. } if name != "suggest_command" => {
|
||||
if in_flight {
|
||||
completed += 1;
|
||||
}
|
||||
in_flight = true;
|
||||
}
|
||||
ConversationEvent::ToolResult { .. } => {
|
||||
if in_flight {
|
||||
completed += 1;
|
||||
in_flight = false;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
/// Cancel during generation
|
||||
pub fn cancel_generation(&mut self) {
|
||||
if let Some(abort) = self.stream_abort.take() {
|
||||
abort.abort();
|
||||
}
|
||||
if let Some(ConversationEvent::UserMessage { .. }) = self.conversation.events.last() {
|
||||
self.conversation.events.pop();
|
||||
}
|
||||
self.interaction.mode = AppMode::Input;
|
||||
}
|
||||
|
||||
// ===== Streaming lifecycle methods =====
|
||||
|
||||
/// Start streaming response.
|
||||
/// Pushes an empty Text event that will be mutated in-place as chunks arrive.
|
||||
pub fn start_streaming(&mut self) {
|
||||
self.conversation.events.push(ConversationEvent::Text {
|
||||
content: String::new(),
|
||||
});
|
||||
self.interaction.streaming_status = None;
|
||||
self.interaction.was_interrupted = false;
|
||||
self.interaction.mode = AppMode::Streaming;
|
||||
}
|
||||
|
||||
/// Update streaming status from SSE event
|
||||
pub fn update_streaming_status(&mut self, status: &str) {
|
||||
self.interaction.streaming_status = Some(StreamingStatus::from_status_str(status));
|
||||
}
|
||||
|
||||
/// Cancel streaming with context preservation
|
||||
pub fn cancel_streaming(&mut self) {
|
||||
if let Some(abort) = self.stream_abort.take() {
|
||||
abort.abort();
|
||||
}
|
||||
self.interaction.was_interrupted = true;
|
||||
|
||||
if let Some(content) = self.conversation.streaming_content_mut() {
|
||||
let trimmed = content.trim_start().to_string();
|
||||
if trimmed.is_empty() {
|
||||
// Remove the empty text event
|
||||
*content = String::new();
|
||||
} else {
|
||||
*content = format!("{trimmed}\n\n[User cancelled this generation]");
|
||||
}
|
||||
}
|
||||
// Remove trailing empty Text events
|
||||
self.conversation.remove_empty_trailing_text();
|
||||
|
||||
self.interaction.streaming_status = None;
|
||||
self.interaction.confirmation_pending = false;
|
||||
self.interaction.mode = AppMode::Input;
|
||||
}
|
||||
|
||||
/// Add a tool call event during streaming.
|
||||
/// The current streaming text is already in events, so we just push the tool call.
|
||||
pub fn add_tool_call(&mut self, id: String, name: String, input: serde_json::Value) {
|
||||
// Trim the streaming text event
|
||||
if let Some(content) = self.conversation.streaming_content_mut() {
|
||||
let trimmed = content.trim_start().to_string();
|
||||
*content = trimmed;
|
||||
}
|
||||
self.conversation.remove_empty_trailing_text();
|
||||
|
||||
let is_suggest_command = name == "suggest_command";
|
||||
self.conversation
|
||||
.events
|
||||
.push(ConversationEvent::ToolCall { id, name, input });
|
||||
|
||||
if is_suggest_command {
|
||||
self.interaction.streaming_status = None;
|
||||
self.interaction.mode = AppMode::Input;
|
||||
}
|
||||
}
|
||||
|
||||
/// Finalize streaming — trim the accumulated text and change mode
|
||||
pub fn finalize_streaming(&mut self) {
|
||||
if let Some(content) = self.conversation.streaming_content_mut() {
|
||||
let trimmed = content.trim_start().to_string();
|
||||
*content = trimmed;
|
||||
}
|
||||
self.conversation.remove_empty_trailing_text();
|
||||
self.interaction.streaming_status = None;
|
||||
self.interaction.mode = AppMode::Input;
|
||||
}
|
||||
|
||||
/// Streaming error — remove the partial text event
|
||||
pub fn streaming_error(&mut self, error: String) {
|
||||
self.conversation.remove_empty_trailing_text();
|
||||
self.interaction.error = Some(error);
|
||||
self.interaction.mode = AppMode::Error;
|
||||
}
|
||||
|
||||
pub(crate) fn handle_client_tool_call(
|
||||
&mut self,
|
||||
id: String,
|
||||
tool: ClientToolCall,
|
||||
input: serde_json::Value,
|
||||
) {
|
||||
let desc = tool.descriptor();
|
||||
let name = desc.canonical_names[0].to_string();
|
||||
|
||||
self.tool_tracker.insert(id.clone(), tool);
|
||||
|
||||
// Add the ToolCall event to the conversation immediately so it appears
|
||||
// in the view. Preview data is sourced from tool_tracker.
|
||||
self.conversation
|
||||
.events
|
||||
.push(ConversationEvent::ToolCall { id, name, input });
|
||||
|
||||
// Client tool calls can only happen at the last part of a turn
|
||||
self.interaction.streaming_status = None;
|
||||
self.interaction.mode = AppMode::Input;
|
||||
}
|
||||
|
||||
/// Retry after error
|
||||
pub fn retry(&mut self) {
|
||||
self.interaction.error = None;
|
||||
self.interaction.mode = AppMode::Generating;
|
||||
}
|
||||
|
||||
// ===== Tool lifecycle methods =====
|
||||
|
||||
/// Finish a tool call: transition tracker to Completed, push ToolResult to conversation.
|
||||
///
|
||||
/// For shell commands, captures the final preview from the ExecutingWithPreview phase
|
||||
/// and patches exit_code/interrupted from the authoritative ToolOutcome.
|
||||
pub fn finish_tool_call(&mut self, tool_id: &str, outcome: ToolOutcome) {
|
||||
let mut preview = self.tool_tracker.get(tool_id).and_then(|t| t.preview());
|
||||
|
||||
// Patch preview with authoritative outcome data (handles race where
|
||||
// final VT100 update hasn't been applied yet).
|
||||
if let Some(ref mut p) = preview
|
||||
&& let ToolOutcome::Structured {
|
||||
exit_code,
|
||||
interrupted,
|
||||
..
|
||||
} = &outcome
|
||||
{
|
||||
p.interrupted = *interrupted;
|
||||
if p.exit_code.is_none() {
|
||||
p.exit_code = *exit_code;
|
||||
}
|
||||
}
|
||||
|
||||
completed
|
||||
}
|
||||
// Transition tracker entry to Completed
|
||||
if let Some(tracked) = self.tool_tracker.get_mut(tool_id) {
|
||||
tracked.complete(preview);
|
||||
}
|
||||
|
||||
/// Check if any turn in the conversation has a command
|
||||
pub fn has_any_command(&self) -> bool {
|
||||
self.events.iter().any(|e| {
|
||||
if let ConversationEvent::ToolCall { name, input, .. } = e {
|
||||
name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
let content = outcome.format_for_llm();
|
||||
let is_error = outcome.is_error();
|
||||
self.conversation
|
||||
.add_tool_result(tool_id.to_string(), content, is_error);
|
||||
}
|
||||
|
||||
/// Get the footer text for current mode
|
||||
pub fn footer_text(&self) -> &'static str {
|
||||
match self.mode {
|
||||
match self.interaction.mode {
|
||||
AppMode::Input => {
|
||||
if self.has_any_command() && self.is_input_blank {
|
||||
if self.confirmation_pending {
|
||||
if self.conversation.has_any_command() && self.interaction.is_input_blank {
|
||||
if self.interaction.confirmation_pending {
|
||||
"[Enter] Confirm dangerous command [Esc] Cancel"
|
||||
} else {
|
||||
"[Enter] Execute suggested command [Tab] Insert Command"
|
||||
@@ -542,9 +590,3 @@ impl AppState {
|
||||
self.exit_action.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
//! View function that builds the eye-declare element tree from app state.
|
||||
|
||||
use eye_declare::{
|
||||
Cells, Column, Elements, HStack, Span, Spinner, Text, View, WidthConstraint, element,
|
||||
BorderType, Cells, Column, Elements, HStack, Span, Spinner, Text, View, Viewport,
|
||||
WidthConstraint, element,
|
||||
};
|
||||
use ratatui_core::style::{Color, Modifier, Style};
|
||||
|
||||
use crate::tools::{ClientToolCall, TrackedTool};
|
||||
use crate::tui::components::select::SelectOption;
|
||||
use crate::tui::events::{AiTuiEvent, PermissionResult};
|
||||
|
||||
use super::components::atuin_ai::AtuinAi;
|
||||
use super::components::input_box::InputBox;
|
||||
use super::components::markdown::Markdown;
|
||||
use super::state::{AppMode, AppState};
|
||||
use super::components::select::Select;
|
||||
use super::state::{AppMode, Session};
|
||||
|
||||
mod turn;
|
||||
|
||||
@@ -20,23 +26,25 @@ mod turn;
|
||||
/// - Error display (if in error state)
|
||||
/// - Spacer
|
||||
/// - Input box (bordered, with contextual keybindings)
|
||||
pub fn ai_view(state: &AppState) -> Elements {
|
||||
let mut turn_builder = turn::TurnBuilder::new();
|
||||
pub(crate) fn ai_view(state: &Session) -> Elements {
|
||||
let mut turn_builder = turn::TurnBuilder::new(&state.tool_tracker);
|
||||
|
||||
for event in &state.events {
|
||||
for event in &state.conversation.events {
|
||||
turn_builder.add_event(event);
|
||||
}
|
||||
let turns = turn_builder.build();
|
||||
|
||||
let busy = state.mode == AppMode::Streaming || state.mode == AppMode::Generating;
|
||||
let busy = state.interaction.mode == AppMode::Streaming
|
||||
|| state.interaction.mode == AppMode::Generating;
|
||||
let last_index = turns.len().saturating_sub(1);
|
||||
|
||||
element! {
|
||||
AtuinAi(
|
||||
mode: state.mode,
|
||||
has_command: state.has_any_command(),
|
||||
is_input_blank: state.is_input_blank,
|
||||
pending_confirmation: state.confirmation_pending,
|
||||
mode: state.interaction.mode,
|
||||
has_command: state.conversation.has_any_command(),
|
||||
is_input_blank: state.interaction.is_input_blank,
|
||||
pending_confirmation: state.interaction.confirmation_pending,
|
||||
has_executing_preview: state.tool_tracker.has_executing_preview(),
|
||||
) {
|
||||
#(for (index, turn) in turns.iter().enumerate() {
|
||||
#(match turn {
|
||||
@@ -53,29 +61,98 @@ pub fn ai_view(state: &AppState) -> Elements {
|
||||
})
|
||||
|
||||
#(if !state.is_exiting() {
|
||||
View(key: "input-box", padding_top: Cells::from(1)) {
|
||||
InputBox(
|
||||
key: "input",
|
||||
title: "Generate a command or ask a question",
|
||||
title_right: "Atuin AI",
|
||||
footer: state.footer_text(),
|
||||
active: state.mode == AppMode::Input && !state.confirmation_pending,
|
||||
)
|
||||
|
||||
#(if state.is_input_blank && state.has_any_command() && state.mode == AppMode::Input {
|
||||
#(if state.confirmation_pending {
|
||||
Text { Span(text: "[Enter] Confirm dangerous command [Esc] Cancel", style: Style::default().fg(Color::Gray)) }
|
||||
} else {
|
||||
Text { Span(text: "[Enter] Execute suggested command [Tab] Insert Command", style: Style::default().fg(Color::Gray)) }
|
||||
})
|
||||
})
|
||||
|
||||
}
|
||||
#(input_view(state))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn input_view(state: &Session) -> Elements {
|
||||
let asking_tool = state.tool_tracker.asking_for_permission();
|
||||
let in_git_project = state.in_git_project;
|
||||
|
||||
element! {
|
||||
#(if let Some(tc) = asking_tool {
|
||||
#(tool_call_view(tc, in_git_project))
|
||||
})
|
||||
|
||||
#(if asking_tool.is_none() {
|
||||
View(key: "input-box", padding_top: Cells::from(1)) {
|
||||
InputBox(
|
||||
key: "input",
|
||||
title: "Generate a command or ask a question",
|
||||
title_right: "Atuin AI",
|
||||
footer: state.footer_text(),
|
||||
active: state.interaction.mode == AppMode::Input && !state.interaction.confirmation_pending,
|
||||
)
|
||||
|
||||
#(if state.interaction.is_input_blank && state.conversation.has_any_command() && state.interaction.mode == AppMode::Input {
|
||||
#(if state.interaction.confirmation_pending {
|
||||
Text { Span(text: "[Enter] Confirm dangerous command [Esc] Cancel", style: Style::default().fg(Color::Gray)) }
|
||||
} else {
|
||||
Text { Span(text: "[Enter] Execute suggested command [Tab] Insert Command", style: Style::default().fg(Color::Gray)) }
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn tool_call_view(tool_call: &TrackedTool, in_git_project: bool) -> Elements {
|
||||
let verb = tool_call.tool.descriptor().display_verb;
|
||||
let tool_desc = match &tool_call.tool {
|
||||
ClientToolCall::Read(tool) => tool.path.display().to_string(),
|
||||
ClientToolCall::Write(tool) => tool.path.display().to_string(),
|
||||
ClientToolCall::Shell(tool) => tool.command.clone(),
|
||||
ClientToolCall::AtuinHistory(tool) => tool.query.clone(),
|
||||
};
|
||||
|
||||
let dir_label = if in_git_project {
|
||||
"Always allow in this workspace"
|
||||
} else {
|
||||
"Always allow in this directory"
|
||||
};
|
||||
|
||||
element! {
|
||||
View(key: format!("tool-call-{}", tool_call.id), padding_left: Cells::from(2), padding_top: Cells::from(1)) {
|
||||
Text {
|
||||
Span(text: format!("Atuin AI would like to {}: ", verb), style: Style::default())
|
||||
Span(text: &tool_desc, style: Style::default().fg(Color::Yellow))
|
||||
}
|
||||
View(padding_left: Cells::from(2)) {
|
||||
Select(options: [
|
||||
SelectOption::builder()
|
||||
.label("Allow")
|
||||
.value("allow")
|
||||
.build(),
|
||||
SelectOption::builder()
|
||||
.label(dir_label)
|
||||
.value("always-allow-in-dir")
|
||||
.build(),
|
||||
SelectOption::builder()
|
||||
.label("Always allow")
|
||||
.value("always-allow")
|
||||
.build(),
|
||||
SelectOption::builder()
|
||||
.label("Deny")
|
||||
.value("deny")
|
||||
.build(),
|
||||
], on_select: Box::new(move |option: &SelectOption| {
|
||||
let value = match option.value.as_str() {
|
||||
"allow" => PermissionResult::Allow,
|
||||
"always-allow-in-dir" => PermissionResult::AlwaysAllowInDir,
|
||||
"always-allow" => PermissionResult::AlwaysAllow,
|
||||
"deny" => PermissionResult::Deny,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Some(AiTuiEvent::SelectPermission(value))
|
||||
}) as Box<dyn Fn(&SelectOption) -> Option<AiTuiEvent> + Send + Sync>)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn user_turn_view(events: &[turn::UiEvent], first_turn: bool) -> Elements {
|
||||
let label_style = Style::default()
|
||||
.fg(Color::Cyan)
|
||||
@@ -86,7 +163,7 @@ fn user_turn_view(events: &[turn::UiEvent], first_turn: bool) -> Elements {
|
||||
element! {
|
||||
View(padding_top: Cells::from(padding)) {
|
||||
Text {
|
||||
Span(text: "You", style: label_style)
|
||||
Span(text: " You ", style: label_style.reversed())
|
||||
}
|
||||
#(for event in events {
|
||||
#(match event {
|
||||
@@ -114,9 +191,9 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements {
|
||||
element! {
|
||||
View {
|
||||
Spinner(
|
||||
label: "Atuin AI",
|
||||
label_style: label_style,
|
||||
done_label_style: label_style,
|
||||
label: " Atuin AI ",
|
||||
label_style: label_style.reversed(),
|
||||
done_label_style: label_style.reversed(),
|
||||
hide_checkmark: true,
|
||||
label_first: true,
|
||||
done: !busy,
|
||||
@@ -136,6 +213,52 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements {
|
||||
turn::UiEvent::SuggestedCommand(details) => {
|
||||
suggested_command_view(details)
|
||||
},
|
||||
turn::UiEvent::ToolCall(details) => {
|
||||
let preview_done = details.preview.as_ref().is_some_and(|p| p.exit_code.is_some() || p.interrupted);
|
||||
let tool_key = details.tool_use_id.clone();
|
||||
|
||||
element! {
|
||||
View(key: format!("tool-output-{tool_key}"), padding_left: Cells::from(2)) {
|
||||
#(if let Some(ref preview) = details.preview {
|
||||
View(key: format!("preview-{tool_key}")) {
|
||||
#(preview_spinner_view(&details.name, preview_done))
|
||||
Viewport(
|
||||
key: format!("viewport-{tool_key}"),
|
||||
lines: preview.lines.clone(),
|
||||
height: 10,
|
||||
border: BorderType::Plain,
|
||||
border_style: Style::default().fg(Color::DarkGray),
|
||||
style: Style::default().fg(Color::White),
|
||||
wrap: false,
|
||||
)
|
||||
#(if let Some(code) = preview.exit_code {
|
||||
#(if code == 0 {
|
||||
Text {
|
||||
Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Green))
|
||||
}
|
||||
} else {
|
||||
Text {
|
||||
Span(text: format!("Exit code: {code}"), style: Style::default().fg(Color::Red))
|
||||
}
|
||||
})
|
||||
})
|
||||
#(if preview.interrupted {
|
||||
Text {
|
||||
Span(text: "Interrupted", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD))
|
||||
}
|
||||
})
|
||||
#(if !preview_done {
|
||||
Text {
|
||||
Span(text: "[Ctrl+C] Interrupt", style: Style::default().fg(Color::DarkGray))
|
||||
}
|
||||
})
|
||||
}
|
||||
} else {
|
||||
#(tool_status_view(&details.name, &details.status))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => element!{}
|
||||
})
|
||||
})
|
||||
@@ -180,6 +303,48 @@ fn tool_summary_view(summary: &turn::ToolSummary) -> Elements {
|
||||
}
|
||||
}
|
||||
|
||||
/// Render a status indicator for a non-preview tool call (e.g. atuin_history, read_file).
|
||||
fn tool_status_view(name: &str, status: &turn::ToolResultStatus) -> Elements {
|
||||
match status {
|
||||
turn::ToolResultStatus::Pending => {
|
||||
element! {
|
||||
Spinner(
|
||||
label: format!("Running: {name}"),
|
||||
label_style: Style::default().fg(Color::Yellow),
|
||||
done: false,
|
||||
)
|
||||
}
|
||||
}
|
||||
turn::ToolResultStatus::Success => {
|
||||
element! {
|
||||
Spinner(
|
||||
label: format!("Ran: {name}"),
|
||||
done: true,
|
||||
)
|
||||
}
|
||||
}
|
||||
turn::ToolResultStatus::Error => {
|
||||
element! {
|
||||
Text {
|
||||
Span(text: "✗ ", style: Style::default().fg(Color::Red))
|
||||
Span(text: format!("{name}: denied"), style: Style::default().fg(Color::Red))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Render a spinner/status line for a command preview (shell tools).
|
||||
fn preview_spinner_view(name: &str, done: bool) -> Elements {
|
||||
element! {
|
||||
Spinner(
|
||||
label: if done { format!("Ran: {name}") } else { format!("Running: {name}") },
|
||||
label_style: Style::default().fg(Color::Yellow),
|
||||
done: done,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn suggested_command_view(details: &turn::SuggestedCommandDetails) -> Elements {
|
||||
let is_dangerous = matches!(
|
||||
details.danger_level,
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use crate::tools::descriptor;
|
||||
use crate::tools::{ToolPreview, ToolTracker};
|
||||
use crate::tui::ConversationEvent;
|
||||
|
||||
/// Server-sent danger level for a suggested command
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum DangerLevel {
|
||||
Low(Option<String>),
|
||||
@@ -37,6 +40,7 @@ impl From<(&String, &String)> for DangerLevel {
|
||||
}
|
||||
}
|
||||
|
||||
/// Server-sent confidence level for a suggested command
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum ConfidenceLevel {
|
||||
Low(Option<String>),
|
||||
@@ -85,9 +89,11 @@ pub(crate) enum UiEvent {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ToolCallDetails {
|
||||
tool_use_id: String,
|
||||
name: String,
|
||||
status: ToolResultStatus,
|
||||
pub(crate) tool_use_id: String,
|
||||
pub(crate) name: String,
|
||||
pub(crate) status: ToolResultStatus,
|
||||
pub(crate) is_client: bool,
|
||||
pub(crate) preview: Option<ToolPreview>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -118,16 +124,19 @@ pub(crate) enum UiTurn {
|
||||
OutOfBand { events: Vec<UiEvent> },
|
||||
}
|
||||
|
||||
pub(crate) struct TurnBuilder {
|
||||
pub(crate) struct TurnBuilder<'a> {
|
||||
turns: Vec<UiTurn>,
|
||||
current_turn: Option<UiTurn>,
|
||||
tracker: &'a ToolTracker,
|
||||
}
|
||||
|
||||
impl TurnBuilder {
|
||||
pub(crate) fn new() -> Self {
|
||||
/// A struct to iteratively build [UiTurn] events from [ConversationEvent]s.
|
||||
impl<'a> TurnBuilder<'a> {
|
||||
pub(crate) fn new(tracker: &'a ToolTracker) -> Self {
|
||||
Self {
|
||||
turns: Vec::new(),
|
||||
current_turn: None,
|
||||
tracker,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,7 +183,7 @@ impl TurnBuilder {
|
||||
|
||||
for event in events.drain(..) {
|
||||
match event {
|
||||
UiEvent::ToolCall(details) => {
|
||||
UiEvent::ToolCall(details) if !details.is_client => {
|
||||
pending_tools.push(details);
|
||||
}
|
||||
other => {
|
||||
@@ -306,12 +315,17 @@ impl TurnBuilder {
|
||||
}
|
||||
|
||||
fn add_tool_call(&mut self, id: &str, name: &str, _input: &serde_json::Value) {
|
||||
let is_client = descriptor::by_name(name).is_some_and(|d| d.is_client);
|
||||
let preview = self.tracker.preview_for(id);
|
||||
|
||||
self.start_agent_turn();
|
||||
if let UiTurn::Agent { events } = self.turn_mut_unsafe() {
|
||||
events.push(UiEvent::ToolCall(ToolCallDetails {
|
||||
tool_use_id: id.to_string(),
|
||||
name: name.to_string(),
|
||||
status: ToolResultStatus::Pending,
|
||||
is_client,
|
||||
preview,
|
||||
}));
|
||||
}
|
||||
}
|
||||
@@ -385,25 +399,15 @@ impl ToolSummary {
|
||||
|
||||
/// Present-tense progressive verb for a tool name (e.g. "Searching...")
|
||||
fn progressive_verb(name: &str) -> String {
|
||||
match name {
|
||||
"search" => "Searching...".into(),
|
||||
"read" | "read_file" => "Reading file...".into(),
|
||||
"write" | "write_file" => "Writing file...".into(),
|
||||
"execute" | "run" | "bash" => "Running command...".into(),
|
||||
"list" | "list_files" => "Listing files...".into(),
|
||||
_ => format!("Running {}...", name.replace('_', " ")),
|
||||
}
|
||||
descriptor::by_name(name)
|
||||
.map(|d| d.progressive_verb.to_string())
|
||||
.unwrap_or_else(|| format!("Running {}...", name.replace('_', " ")))
|
||||
}
|
||||
|
||||
/// Past-tense verb for a tool name (e.g. "Searched")
|
||||
fn past_verb(name: &str) -> String {
|
||||
match name {
|
||||
"search" => "Searched".into(),
|
||||
"read" | "read_file" => "Read file".into(),
|
||||
"write" | "write_file" => "Wrote file".into(),
|
||||
"execute" | "run" | "bash" => "Ran command".into(),
|
||||
"list" | "list_files" => "Listed files".into(),
|
||||
_ => format!("Ran {}", name.replace('_', " ")),
|
||||
}
|
||||
descriptor::by_name(name)
|
||||
.map(|d| d.past_verb.to_string())
|
||||
.unwrap_or_else(|| format!("Ran {}", name.replace('_', " ")))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ rand = { workspace = true }
|
||||
shellexpand = "3"
|
||||
sqlx = { workspace = true, features = ["sqlite", "regexp"] }
|
||||
minspan = "0.1.5"
|
||||
regex = "1.10.5"
|
||||
regex = { workspace = true }
|
||||
serde_regex = "1.1.0"
|
||||
fs-err = { workspace = true }
|
||||
sql-builder = { workspace = true }
|
||||
|
||||
@@ -671,6 +671,16 @@ pub struct Ai {
|
||||
/// Configuration for what context is sent in the opening AI request.
|
||||
#[serde(default)]
|
||||
pub opening: AiOpening,
|
||||
|
||||
/// Tool capability flags.
|
||||
#[serde(default)]
|
||||
pub capabilities: AiCapabilities,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct AiCapabilities {
|
||||
/// Whether the AI can request to search Atuin history. `None` = unset (defaults to enabled, and the ai will ask for permission).
|
||||
pub enable_history_search: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Deserialize, Serialize)]
|
||||
|
||||
@@ -18,4 +18,4 @@ crossterm = { workspace = true }
|
||||
eyre = { workspace = true }
|
||||
portable-pty = "0.8"
|
||||
signal-hook = "0.3"
|
||||
vt100 = "0.15"
|
||||
vt100 = { workspace = true }
|
||||
|
||||
@@ -222,7 +222,7 @@ mod app {
|
||||
fn handle_parser_msg(parser: &mut vt100::Parser, msg: ParserMsg) {
|
||||
match msg {
|
||||
ParserMsg::Data(data) => parser.process(&data),
|
||||
ParserMsg::Resize { rows, cols } => parser.set_size(rows, cols),
|
||||
ParserMsg::Resize { rows, cols } => parser.screen_mut().set_size(rows, cols),
|
||||
ParserMsg::ScreenRequest(reply_tx) => {
|
||||
let _ = reply_tx.send(encode_screen(parser));
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ tempfile = { workspace = true }
|
||||
shlex = "1.3.0"
|
||||
|
||||
# settings editor with comment and relative ordering preservation
|
||||
toml_edit = "0.25.4"
|
||||
toml_edit = { workspace = true }
|
||||
|
||||
[target.'cfg(any(target_os = "windows", target_os = "macos"))'.dependencies]
|
||||
arboard = { version = "3.4", optional = true, default-features = false }
|
||||
@@ -105,6 +105,12 @@ arboard = { version = "3.4", optional = true, default-features = false, features
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
daemonize = "0.5.0"
|
||||
|
||||
# Enable tree-sitter shell parsing on platforms where tree-sitter's bundled C
|
||||
# compiles cleanly. tree-sitter 0.26's portable/endian.h fails on illumos,
|
||||
# Windows cross-compiles, and potentially other exotic targets.
|
||||
[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dependencies]
|
||||
atuin-ai = { path = "../atuin-ai", version = "18.13.6", optional = true, default-features = false, features = ["tree-sitter"] }
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
windows-sys = { version = "0.61.2", features = ["Win32_System_Console"] }
|
||||
|
||||
|
||||
@@ -154,11 +154,21 @@ impl Cmd {
|
||||
daemon::daemonize_current_process()?;
|
||||
}
|
||||
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
#[cfg(feature = "ai")]
|
||||
let mut runtime = if matches!(&self, Self::Ai(_)) {
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
} else {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "ai"))]
|
||||
let mut runtime = tokio::runtime::Builder::new_current_thread();
|
||||
|
||||
let runtime = runtime.enable_all().build().unwrap();
|
||||
|
||||
// For non-history commands, we want to initialize logging and the theme manager before
|
||||
// doing anything else. History commands are performance-sensitive and run before and after
|
||||
// every shell command, so we want to skip any unnecessary initialization for them.
|
||||
let settings = Settings::new().wrap_err("could not load client settings")?;
|
||||
let theme_manager = theme::ThemeManager::new(settings.theme.debug, None);
|
||||
let res = runtime.block_on(self.run_inner(settings, theme_manager));
|
||||
|
||||
+29
-12
@@ -8,6 +8,35 @@ Default: `false`
|
||||
|
||||
Whether or not the AI feature are enabled. When set to `false`, the question mark keybinding will output a message with instructions to run `atuin setup` to enable the feature.
|
||||
|
||||
### endpoint
|
||||
|
||||
Default: `null`
|
||||
|
||||
The address of the Atuin AI endpoint. Used for AI features like command generation. Most users will not need this setting; it is only necessary for custom AI endpoints.
|
||||
|
||||
### api_token
|
||||
|
||||
Default: `null`
|
||||
|
||||
The API token for the Atuin AI endpoint. Used for AI features like command generation. Most users will not need this setting; it is only necessary for custom AI endpoints.
|
||||
|
||||
## Capabilities
|
||||
|
||||
Settings that control what capabilities are sent to the LLM. These are specified under `[ai.capabilities]`.
|
||||
|
||||
### enable_history_search
|
||||
|
||||
Default: `true`
|
||||
|
||||
Whether or not to include the "history search" capability in the context sent to the LLM. This allows the AI to request to search your command history for relevant commands when generating suggestions or answering questions.
|
||||
|
||||
**Example config**
|
||||
|
||||
```toml
|
||||
[ai.capabilities]
|
||||
enable_history_search = false
|
||||
```
|
||||
|
||||
## Opening context
|
||||
|
||||
Settings that control what context is sent in the opening AI request. These are specified under `[ai.opening]`.
|
||||
@@ -37,15 +66,3 @@ Whether or not to send your previous command as context in the initial request,
|
||||
[ai.opening]
|
||||
send_last_command = true
|
||||
```
|
||||
|
||||
### endpoint
|
||||
|
||||
Default: `null`
|
||||
|
||||
The address of the Atuin AI endpoint. Used for AI features like command generation. Most users will not need this setting; it is only necessary for custom AI endpoints.
|
||||
|
||||
### api_token
|
||||
|
||||
Default: `null`
|
||||
|
||||
The API token for the Atuin AI endpoint. Used for AI features like command generation. Most users will not need this setting; it is only necessary for custom AI endpoints.
|
||||
|
||||
Reference in New Issue
Block a user