diff --git a/src/config.rs b/src/config.rs index a456013..5fbece4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,13 +1,10 @@ -use clap::{arg, parser::ValueSource, value_parser, ArgMatches, Command}; -use serde::Deserialize; -use serde_default::DefaultFromSerde; -use std::{ - io::ErrorKind, - path::{Path, PathBuf}, - time::Duration, -}; +mod defaults; +mod file_config; use crate::BoxedError; +use clap::{arg, value_parser, Command}; +use file_config::FileConfig; +use std::{path::PathBuf, time::Duration}; pub struct Config { pub port: u32, @@ -19,142 +16,6 @@ pub struct Config { pub active_window_bucket_name: String, } -fn default_idle_timeout_seconds() -> u32 { - 180 -} -fn default_poll_time_idle_seconds() -> u32 { - 5 -} -fn default_poll_time_window_seconds() -> u32 { - 1 -} -fn default_port() -> u32 { - 5600 -} -fn default_host() -> String { - "localhost".to_string() -} - -#[derive(Deserialize, DefaultFromSerde)] -struct ServerConfig { - #[serde(default = "default_port")] - port: u32, - #[serde(default = "default_host")] - host: String, -} - -#[derive(Deserialize, DefaultFromSerde)] -struct ClientConfig { - #[serde(default = "default_idle_timeout_seconds")] - idle_timeout_seconds: u32, - #[serde(default = "default_poll_time_idle_seconds")] - poll_time_idle_seconds: u32, - #[serde(default = "default_poll_time_window_seconds")] - poll_time_window_seconds: u32, -} - -#[derive(Deserialize, Default)] -struct FileConfig { - #[serde(default)] - server: ServerConfig, - #[serde(default)] - client: ClientConfig, -} - -impl FileConfig { - fn new(matches: &ArgMatches) -> Result { - let mut config_path: PathBuf = dirs::config_dir().ok_or("Config directory is unknown")?; - config_path.push("awatcher"); - config_path.push("config.toml"); - if matches.contains_id("config") { - let config_file = matches.get_one::("config"); - if let Some(path) = config_file { - if let Err(e) = std::fs::metadata(path) { - warn!("Invalid config filename, using the default config: {e}"); - } else { - config_path = Path::new(path).to_path_buf(); - } - } - } - - if config_path.exists() { - debug!("Reading config at {}", config_path.display()); - let config_content = std::fs::read_to_string(config_path) - .map_err(|e| format!("Impossible to read config file: {e}"))?; - - Ok(toml::from_str(&config_content)?) - } else { - let config = format!( - r#"# The commented values are the defaults on the file creation -[server] -# port = {} -# host = "{}" -[awatcher] -# idle-timeout-seconds={} -# poll-time-idle-seconds={} -# poll-time-window-seconds={} -"#, - default_port(), - default_host(), - default_idle_timeout_seconds(), - default_poll_time_idle_seconds(), - default_poll_time_window_seconds(), - ); - let error = std::fs::create_dir(config_path.parent().unwrap()); - if let Err(e) = error { - if e.kind() != ErrorKind::AlreadyExists { - Err(e)?; - } - } - debug!("Creading config at {}", config_path.display()); - std::fs::write(config_path, config)?; - - Ok(Self::default()) - } - } - - fn merge_cli(&mut self, matches: &ArgMatches) { - self.client.poll_time_idle_seconds = get_arg_value( - "poll-time-idle", - matches, - self.client.poll_time_idle_seconds, - ); - self.client.poll_time_window_seconds = get_arg_value( - "poll-time-window", - matches, - self.client.poll_time_window_seconds, - ); - self.client.idle_timeout_seconds = - get_arg_value("idle-timeout", matches, self.client.idle_timeout_seconds); - - self.server.port = get_arg_value("port", matches, self.server.port); - self.server.host = get_arg_value("host", matches, self.server.host.clone()); - } - - fn get_idle_timeout(&self) -> Duration { - Duration::from_secs(u64::from(self.client.idle_timeout_seconds)) - } - - fn get_poll_time_idle(&self) -> Duration { - Duration::from_secs(u64::from(self.client.poll_time_idle_seconds)) - } - - fn get_poll_time_window(&self) -> Duration { - Duration::from_secs(u64::from(self.client.poll_time_window_seconds)) - } -} - -fn get_arg_value(id: &str, matches: &ArgMatches, config_value: T) -> T -where - T: Clone + Send + Sync + 'static, -{ - if let Some(ValueSource::CommandLine) = matches.value_source(id) { - matches.get_one::(id).unwrap().clone() - } else { - config_value - } -} - impl Config { pub fn from_cli() -> Result { let matches = Command::new("Activity Watcher") @@ -164,24 +25,23 @@ impl Config { arg!(-c --config "Custom config file").value_parser(value_parser!(PathBuf)), arg!(--port "Custom server port") .value_parser(value_parser!(u32)) - .default_value(default_port().to_string()), + .default_value(defaults::port().to_string()), arg!(--host "Custom server host") .value_parser(value_parser!(String)) - .default_value(default_host()), + .default_value(defaults::host()), arg!(--"idle-timeout" "Time of inactivity to consider the user idle") .value_parser(value_parser!(u32)) - .default_value(default_idle_timeout_seconds().to_string()), + .default_value(defaults::idle_timeout_seconds().to_string()), arg!(--"poll-time-idle" "Period between sending heartbeats to the server for idle activity") .value_parser(value_parser!(u32)) - .default_value(default_poll_time_idle_seconds().to_string()), + .default_value(defaults::poll_time_idle_seconds().to_string()), arg!(--"poll-time-window" "Period between sending heartbeats to the server for idle activity") .value_parser(value_parser!(u32)) - .default_value(default_poll_time_window_seconds().to_string()), + .default_value(defaults::poll_time_window_seconds().to_string()), ]) .get_matches(); - let mut config = FileConfig::new(&matches)?; - config.merge_cli(&matches); + let config = FileConfig::new(&matches)?; let hostname = gethostname::gethostname().into_string().unwrap(); let idle_bucket_name = format!("aw-watcher-afk_{hostname}"); @@ -190,9 +50,9 @@ impl Config { Ok(Self { port: config.server.port, host: config.server.host.clone(), - idle_timeout: config.get_idle_timeout(), - poll_time_idle: config.get_poll_time_idle(), - poll_time_window: config.get_poll_time_window(), + idle_timeout: config.client.get_idle_timeout(), + poll_time_idle: config.client.get_poll_time_idle(), + poll_time_window: config.client.get_poll_time_window(), idle_bucket_name, active_window_bucket_name, }) diff --git a/src/config/defaults.rs b/src/config/defaults.rs new file mode 100644 index 0000000..4031aa3 --- /dev/null +++ b/src/config/defaults.rs @@ -0,0 +1,15 @@ +pub fn idle_timeout_seconds() -> u32 { + 180 +} +pub fn poll_time_idle_seconds() -> u32 { + 5 +} +pub fn poll_time_window_seconds() -> u32 { + 1 +} +pub fn port() -> u32 { + 5600 +} +pub fn host() -> String { + "localhost".to_string() +} diff --git a/src/config/file_config.rs b/src/config/file_config.rs new file mode 100644 index 0000000..67396d3 --- /dev/null +++ b/src/config/file_config.rs @@ -0,0 +1,137 @@ +use clap::{parser::ValueSource, ArgMatches}; +use serde::Deserialize; +use serde_default::DefaultFromSerde; +use std::{ + io::ErrorKind, + path::{Path, PathBuf}, + time::Duration, +}; + +use crate::{config::defaults, BoxedError}; + +#[derive(Deserialize, DefaultFromSerde)] +pub struct ServerConfig { + #[serde(default = "defaults::port")] + pub port: u32, + #[serde(default = "defaults::host")] + pub host: String, +} + +#[derive(Deserialize, DefaultFromSerde)] +pub struct ClientConfig { + #[serde(default = "defaults::idle_timeout_seconds")] + idle_timeout_seconds: u32, + #[serde(default = "defaults::poll_time_idle_seconds")] + poll_time_idle_seconds: u32, + #[serde(default = "defaults::poll_time_window_seconds")] + poll_time_window_seconds: u32, +} + +impl ClientConfig { + pub fn get_idle_timeout(&self) -> Duration { + Duration::from_secs(u64::from(self.idle_timeout_seconds)) + } + + pub fn get_poll_time_idle(&self) -> Duration { + Duration::from_secs(u64::from(self.poll_time_idle_seconds)) + } + + pub fn get_poll_time_window(&self) -> Duration { + Duration::from_secs(u64::from(self.poll_time_window_seconds)) + } +} + +#[derive(Deserialize, Default)] +pub struct FileConfig { + #[serde(default)] + pub server: ServerConfig, + #[serde(default)] + #[serde(rename = "awatcher")] + pub client: ClientConfig, +} + +impl FileConfig { + pub fn new(matches: &ArgMatches) -> Result { + let mut config_path: PathBuf = dirs::config_dir().ok_or("Config directory is unknown")?; + config_path.push("awatcher"); + config_path.push("config.toml"); + if matches.contains_id("config") { + let config_file = matches.get_one::("config"); + if let Some(path) = config_file { + if let Err(e) = std::fs::metadata(path) { + warn!("Invalid config filename, using the default config: {e}"); + } else { + config_path = Path::new(path).to_path_buf(); + } + } + } + + let mut config = if config_path.exists() { + debug!("Reading config at {}", config_path.display()); + let config_content = std::fs::read_to_string(config_path) + .map_err(|e| format!("Impossible to read config file: {e}"))?; + + toml::from_str(&config_content)? + } else { + let config = format!( + r#"# The commented values are the defaults on the file creation +[server] +# port = {} +# host = "{}" + +[awatcher] +# idle-timeout-seconds={} +# poll-time-idle-seconds={} +# poll-time-window-seconds={} +"#, + defaults::port(), + defaults::host(), + defaults::idle_timeout_seconds(), + defaults::poll_time_idle_seconds(), + defaults::poll_time_window_seconds(), + ); + let error = std::fs::create_dir(config_path.parent().unwrap()); + if let Err(e) = error { + if e.kind() != ErrorKind::AlreadyExists { + Err(e)?; + } + } + debug!("Creading config at {}", config_path.display()); + std::fs::write(config_path, config)?; + + Self::default() + }; + config.merge_cli(matches); + + Ok(config) + } + + fn merge_cli(&mut self, matches: &ArgMatches) { + self.client.poll_time_idle_seconds = get_arg_value( + "poll-time-idle", + matches, + self.client.poll_time_idle_seconds, + ); + self.client.poll_time_window_seconds = get_arg_value( + "poll-time-window", + matches, + self.client.poll_time_window_seconds, + ); + self.client.idle_timeout_seconds = + get_arg_value("idle-timeout", matches, self.client.idle_timeout_seconds); + + self.server.port = get_arg_value("port", matches, self.server.port); + self.server.host = get_arg_value("host", matches, self.server.host.clone()); + } +} + +fn get_arg_value(id: &str, matches: &ArgMatches, config_value: T) -> T +where + T: Clone + Send + Sync + 'static, +{ + if let Some(ValueSource::CommandLine) = matches.value_source(id) { + matches.get_one::(id).unwrap().clone() + } else { + config_value + } +}