feat(engine): add units, currency, datetime, variables, and functions modules
Extracted and integrated unique feature modules from Epic 2-6 branches: - units/: 200+ unit conversions across 14 categories, SI prefixes (nano-tera), CSS/screen units, binary vs decimal data, custom units - currency/: fiat (180+ currencies, cached rates, offline fallback), crypto (63 coins, CoinGecko), symbol recognition, rate caching - datetime/: date/time math, 150+ city timezone mappings (chrono-tz), business day calculations, unix timestamps, relative expressions - variables/: line references (lineN, #N, prev/ans), section aggregators (sum/total/avg/min/max/count), subtotals, autocomplete - functions/: trig, log, combinatorics, financial, rounding, list operations (min/max/gcd/lcm), video timecodes 585 tests passing across workspace.
This commit is contained in:
@@ -4,7 +4,26 @@
|
||||
"Bash(find:*)",
|
||||
"Bash(ls:*)",
|
||||
"Bash(./run-pipeline.sh --phase1 --dry-run 2>&1 | sed 's/\\\\x1b\\\\[[0-9;]*m//g')",
|
||||
"Bash(git branch:*)"
|
||||
"Bash(git branch:*)",
|
||||
"Bash(kill 5354 5355 5360)",
|
||||
"Bash(git worktree:*)",
|
||||
"Bash(while read:*)",
|
||||
"Bash(do git:*)",
|
||||
"Bash(done)",
|
||||
"Bash(git ls-tree:*)",
|
||||
"Bash(cargo build:*)",
|
||||
"Bash(source ~/.zshrc)",
|
||||
"Bash(source \"$HOME/.cargo/env\")",
|
||||
"Bash(export PATH=\"$HOME/.cargo/bin:$PATH\")",
|
||||
"Bash($HOME/.cargo/bin/cargo build:*)",
|
||||
"Bash($HOME/.cargo/bin/cargo test:*)",
|
||||
"Bash(git status:*)",
|
||||
"Bash(git add:*)",
|
||||
"Bash(git rm:*)",
|
||||
"Bash(git commit:*)",
|
||||
"Bash(/Users/cassel/.cargo/bin/cargo test:*)",
|
||||
"Bash(tee /tmp/test-output.txt)",
|
||||
"Bash(echo \"EXIT: $?\")"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
857
Cargo.lock
generated
857
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,12 @@ edition = "2021"
|
||||
crate-type = ["cdylib", "staticlib", "rlib"]
|
||||
|
||||
[dependencies]
|
||||
chrono = "0.4"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
chrono-tz = "0.10"
|
||||
dashu = "0.4"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
ureq = { version = "2", features = ["json"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
|
||||
@@ -83,6 +83,18 @@ pub enum ExprKind {
|
||||
name: String,
|
||||
value: Box<Expr>,
|
||||
},
|
||||
|
||||
/// Line reference: `line1`, `#1` (1-indexed line number)
|
||||
LineRef(usize),
|
||||
|
||||
/// Previous-line reference: `prev`, `ans`
|
||||
PrevRef,
|
||||
|
||||
/// Function call: `sqrt(4)`, `abs(-5)`
|
||||
FunctionCall {
|
||||
name: String,
|
||||
args: Vec<Expr>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
|
||||
633
calcpad-engine/src/currency/crypto.rs
Normal file
633
calcpad-engine/src/currency/crypto.rs
Normal file
@@ -0,0 +1,633 @@
|
||||
//! Cryptocurrency rate provider with CoinGecko API integration.
|
||||
//!
|
||||
//! Supports 60+ coins with disk caching and offline fallback.
|
||||
//! All rates are stored as USD prices.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Configuration for the cryptocurrency provider.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CryptoProviderConfig {
|
||||
/// CoinGecko API base URL.
|
||||
pub api_url: String,
|
||||
/// Path to the disk cache file.
|
||||
pub cache_path: PathBuf,
|
||||
/// How often to refresh rates (default: 1 hour).
|
||||
pub refresh_interval: Duration,
|
||||
}
|
||||
|
||||
impl Default for CryptoProviderConfig {
|
||||
fn default() -> Self {
|
||||
let cache_dir = dirs_cache_path();
|
||||
Self {
|
||||
api_url: "https://api.coingecko.com/api/v3".to_string(),
|
||||
cache_path: cache_dir.join("crypto_cache.json"),
|
||||
refresh_interval: Duration::from_secs(3600),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dirs_cache_path() -> PathBuf {
|
||||
if let Some(home) = std::env::var_os("HOME") {
|
||||
PathBuf::from(home).join(".calcpad")
|
||||
} else {
|
||||
PathBuf::from(".calcpad")
|
||||
}
|
||||
}
|
||||
|
||||
/// A single cryptocurrency rate entry.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CryptoRate {
|
||||
pub symbol: String,
|
||||
pub name: String,
|
||||
pub usd_price: f64,
|
||||
}
|
||||
|
||||
/// Cached crypto rates stored on disk.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CryptoCache {
|
||||
pub rates: HashMap<String, CryptoRate>,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
pub provider: String,
|
||||
}
|
||||
|
||||
/// The cryptocurrency rate provider.
|
||||
///
|
||||
/// Supports fetching from CoinGecko, caching to disk, and offline fallback.
|
||||
#[derive(Debug)]
|
||||
pub struct CryptoProvider {
|
||||
config: CryptoProviderConfig,
|
||||
cache: Option<CryptoCache>,
|
||||
}
|
||||
|
||||
impl CryptoProvider {
|
||||
/// Create a new CryptoProvider with the given configuration.
|
||||
pub fn new(config: CryptoProviderConfig) -> Self {
|
||||
let mut provider = CryptoProvider {
|
||||
config,
|
||||
cache: None,
|
||||
};
|
||||
provider.load_cache();
|
||||
provider
|
||||
}
|
||||
|
||||
/// Create a CryptoProvider with default configuration.
|
||||
pub fn with_defaults() -> Self {
|
||||
Self::new(CryptoProviderConfig::default())
|
||||
}
|
||||
|
||||
/// Create a CryptoProvider with pre-loaded rates (for testing or offline use).
|
||||
pub fn with_rates(rates: HashMap<String, CryptoRate>, timestamp: DateTime<Utc>) -> Self {
|
||||
CryptoProvider {
|
||||
config: CryptoProviderConfig::default(),
|
||||
cache: Some(CryptoCache {
|
||||
rates,
|
||||
timestamp,
|
||||
provider: "manual".to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the USD price for a crypto symbol (case-insensitive).
|
||||
pub fn get_rate(&self, symbol: &str) -> Option<f64> {
|
||||
let upper = symbol.to_uppercase();
|
||||
self.cache
|
||||
.as_ref()
|
||||
.and_then(|c| c.rates.get(&upper))
|
||||
.map(|r| r.usd_price)
|
||||
}
|
||||
|
||||
/// Check if the cached rates are stale.
|
||||
pub fn is_stale(&self) -> bool {
|
||||
match &self.cache {
|
||||
None => true,
|
||||
Some(cache) => {
|
||||
let age = Utc::now().signed_duration_since(cache.timestamp);
|
||||
age.to_std().unwrap_or(Duration::MAX) > self.config.refresh_interval
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the timestamp of the cached rates.
|
||||
pub fn rate_timestamp(&self) -> Option<DateTime<Utc>> {
|
||||
self.cache.as_ref().map(|c| c.timestamp)
|
||||
}
|
||||
|
||||
/// Get a human-readable description of the rate age.
|
||||
pub fn rate_age_display(&self) -> Option<String> {
|
||||
self.cache.as_ref().map(|c| {
|
||||
let age = Utc::now().signed_duration_since(c.timestamp);
|
||||
let secs = age.num_seconds();
|
||||
if secs < 60 {
|
||||
"just now".to_string()
|
||||
} else if secs < 3600 {
|
||||
format!("{} minutes ago", secs / 60)
|
||||
} else if secs < 86400 {
|
||||
format!("{} hours ago", secs / 3600)
|
||||
} else {
|
||||
format!("{} days ago", secs / 86400)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Refresh rates from the CoinGecko API.
|
||||
/// Returns Ok(()) on success, Err with message on failure.
|
||||
/// Falls back to cached rates if the API call fails.
|
||||
pub fn refresh(&mut self) -> Result<(), String> {
|
||||
if !self.is_stale() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match self.fetch_from_api() {
|
||||
Ok(rates) => {
|
||||
let cache = CryptoCache {
|
||||
rates,
|
||||
timestamp: Utc::now(),
|
||||
provider: "coingecko".to_string(),
|
||||
};
|
||||
self.cache = Some(cache);
|
||||
self.save_cache();
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
if self.cache.is_some() {
|
||||
Err(format!("API fetch failed, using cached rates: {}", e))
|
||||
} else {
|
||||
Err(format!("No rates available: {}", e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch rates from CoinGecko API.
|
||||
fn fetch_from_api(&self) -> Result<HashMap<String, CryptoRate>, String> {
|
||||
let url = format!(
|
||||
"{}/coins/markets?vs_currency=usd&order=market_cap_desc&per_page=100&page=1&sparkline=false",
|
||||
self.config.api_url
|
||||
);
|
||||
|
||||
let response = ureq::get(&url)
|
||||
.call()
|
||||
.map_err(|e| format!("HTTP request failed: {}", e))?;
|
||||
|
||||
let body: serde_json::Value = response
|
||||
.into_json()
|
||||
.map_err(|e| format!("Failed to parse JSON: {}", e))?;
|
||||
|
||||
let entries: Vec<CoinGeckoMarketEntry> = serde_json::from_value(body)
|
||||
.map_err(|e| format!("Failed to deserialize response: {}", e))?;
|
||||
|
||||
let mut rates = HashMap::new();
|
||||
for entry in entries {
|
||||
let symbol = entry.symbol.to_uppercase();
|
||||
rates.insert(
|
||||
symbol.clone(),
|
||||
CryptoRate {
|
||||
symbol,
|
||||
name: entry.name,
|
||||
usd_price: entry.current_price,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(rates)
|
||||
}
|
||||
|
||||
/// Load cache from disk.
|
||||
fn load_cache(&mut self) {
|
||||
if let Ok(data) = std::fs::read_to_string(&self.config.cache_path) {
|
||||
if let Ok(cache) = serde_json::from_str::<CryptoCache>(&data) {
|
||||
self.cache = Some(cache);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Save cache to disk.
|
||||
fn save_cache(&self) {
|
||||
if let Some(ref cache) = self.cache {
|
||||
if let Some(parent) = self.config.cache_path.parent() {
|
||||
let _ = std::fs::create_dir_all(parent);
|
||||
}
|
||||
if let Ok(json) = serde_json::to_string_pretty(cache) {
|
||||
let _ = std::fs::write(&self.config.cache_path, json);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if rates are available (either fresh or cached).
|
||||
pub fn has_rates(&self) -> bool {
|
||||
self.cache.is_some()
|
||||
}
|
||||
|
||||
/// Convert an amount from one currency to another.
|
||||
/// Supports crypto-to-fiat (USD) and fiat(USD)-to-crypto conversions,
|
||||
/// as well as crypto-to-crypto cross-rates via USD.
|
||||
pub fn convert(
|
||||
&self,
|
||||
amount: f64,
|
||||
from: &str,
|
||||
to: &str,
|
||||
) -> Result<(f64, ConversionMeta), String> {
|
||||
let from_upper = from.to_uppercase();
|
||||
let to_upper = to.to_uppercase();
|
||||
|
||||
let cache = self.cache.as_ref().ok_or("No crypto rates available")?;
|
||||
|
||||
let usd_amount = if from_upper == "USD" {
|
||||
amount
|
||||
} else if let Some(rate) = cache.rates.get(&from_upper) {
|
||||
amount * rate.usd_price
|
||||
} else {
|
||||
return Err(format!("Unknown currency: {}", from));
|
||||
};
|
||||
|
||||
let result = if to_upper == "USD" {
|
||||
usd_amount
|
||||
} else if let Some(rate) = cache.rates.get(&to_upper) {
|
||||
if rate.usd_price == 0.0 {
|
||||
return Err(format!("Rate for {} is zero", to));
|
||||
}
|
||||
usd_amount / rate.usd_price
|
||||
} else {
|
||||
return Err(format!("Unknown currency: {}", to));
|
||||
};
|
||||
|
||||
let meta = ConversionMeta {
|
||||
timestamp: cache.timestamp,
|
||||
is_stale: self.is_stale(),
|
||||
age_display: self.rate_age_display().unwrap_or_default(),
|
||||
};
|
||||
|
||||
Ok((result, meta))
|
||||
}
|
||||
}
|
||||
|
||||
/// Metadata about a currency conversion result.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConversionMeta {
|
||||
pub timestamp: DateTime<Utc>,
|
||||
pub is_stale: bool,
|
||||
pub age_display: String,
|
||||
}
|
||||
|
||||
/// CoinGecko API response entry for /coins/markets.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CoinGeckoMarketEntry {
|
||||
symbol: String,
|
||||
name: String,
|
||||
current_price: f64,
|
||||
}
|
||||
|
||||
/// Static mapping of crypto symbols to CoinGecko IDs.
|
||||
/// Top 60+ cryptocurrencies by market cap.
|
||||
///
|
||||
/// This is used for:
|
||||
/// - Recognizing crypto symbols in expressions
|
||||
/// - Looking up CoinGecko IDs for API calls
|
||||
pub(crate) static CRYPTO_SYMBOLS: &[(&str, &str)] = &[
|
||||
("BTC", "bitcoin"),
|
||||
("ETH", "ethereum"),
|
||||
("USDT", "tether"),
|
||||
("BNB", "binancecoin"),
|
||||
("SOL", "solana"),
|
||||
("XRP", "ripple"),
|
||||
("USDC", "usd-coin"),
|
||||
("ADA", "cardano"),
|
||||
("DOGE", "dogecoin"),
|
||||
("TRX", "tron"),
|
||||
("AVAX", "avalanche-2"),
|
||||
("TON", "the-open-network"),
|
||||
("SHIB", "shiba-inu"),
|
||||
("DOT", "polkadot"),
|
||||
("LINK", "chainlink"),
|
||||
("BCH", "bitcoin-cash"),
|
||||
("NEAR", "near"),
|
||||
("DAI", "dai"),
|
||||
("LTC", "litecoin"),
|
||||
("MATIC", "matic-network"),
|
||||
("UNI", "uniswap"),
|
||||
("ICP", "internet-computer"),
|
||||
("LEO", "leo-token"),
|
||||
("APT", "aptos"),
|
||||
("ETC", "ethereum-classic"),
|
||||
("ATOM", "cosmos"),
|
||||
("XLM", "stellar"),
|
||||
("HBAR", "hedera-hashgraph"),
|
||||
("FIL", "filecoin"),
|
||||
("IMX", "immutable-x"),
|
||||
("MNT", "mantle"),
|
||||
("CRO", "crypto-com-chain"),
|
||||
("ARB", "arbitrum"),
|
||||
("OP", "optimism"),
|
||||
("VET", "vechain"),
|
||||
("MKR", "maker"),
|
||||
("ALGO", "algorand"),
|
||||
("GRT", "the-graph"),
|
||||
("AAVE", "aave"),
|
||||
("FTM", "fantom"),
|
||||
("SAND", "the-sandbox"),
|
||||
("THETA", "theta-token"),
|
||||
("AXS", "axie-infinity"),
|
||||
("EOS", "eos"),
|
||||
("XTZ", "tezos"),
|
||||
("MANA", "decentraland"),
|
||||
("FLOW", "flow"),
|
||||
("EGLD", "elrond-erd-2"),
|
||||
("CHZ", "chiliz"),
|
||||
("CAKE", "pancakeswap-token"),
|
||||
("XMR", "monero"),
|
||||
("NEO", "neo"),
|
||||
("IOTA", "iota"),
|
||||
("KLAY", "klay-token"),
|
||||
("PEPE", "pepe"),
|
||||
("SUI", "sui"),
|
||||
("SEI", "sei-network"),
|
||||
("INJ", "injective-protocol"),
|
||||
("RNDR", "render-token"),
|
||||
("RUNE", "thorchain"),
|
||||
("WLD", "worldcoin-wld"),
|
||||
("BONK", "bonk"),
|
||||
];
|
||||
|
||||
/// Check if a symbol is a known cryptocurrency (case-insensitive).
|
||||
pub fn is_known_crypto(symbol: &str) -> bool {
|
||||
let upper = symbol.to_uppercase();
|
||||
CRYPTO_SYMBOLS.iter().any(|(s, _)| *s == upper)
|
||||
}
|
||||
|
||||
/// Get the CoinGecko ID for a crypto symbol (case-insensitive).
|
||||
pub fn get_coingecko_id(symbol: &str) -> Option<&'static str> {
|
||||
let upper = symbol.to_uppercase();
|
||||
CRYPTO_SYMBOLS
|
||||
.iter()
|
||||
.find(|(s, _)| *s == upper)
|
||||
.map(|(_, id)| *id)
|
||||
}
|
||||
|
||||
/// Get all known crypto symbols.
|
||||
pub fn all_crypto_symbols() -> Vec<&'static str> {
|
||||
CRYPTO_SYMBOLS.iter().map(|(s, _)| *s).collect()
|
||||
}
|
||||
|
||||
/// Count of registered crypto symbols.
|
||||
pub fn crypto_symbol_count() -> usize {
|
||||
CRYPTO_SYMBOLS.len()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_test_rates() -> HashMap<String, CryptoRate> {
|
||||
let mut rates = HashMap::new();
|
||||
rates.insert(
|
||||
"BTC".to_string(),
|
||||
CryptoRate {
|
||||
symbol: "BTC".to_string(),
|
||||
name: "Bitcoin".to_string(),
|
||||
usd_price: 50000.0,
|
||||
},
|
||||
);
|
||||
rates.insert(
|
||||
"ETH".to_string(),
|
||||
CryptoRate {
|
||||
symbol: "ETH".to_string(),
|
||||
name: "Ethereum".to_string(),
|
||||
usd_price: 3000.0,
|
||||
},
|
||||
);
|
||||
rates.insert(
|
||||
"SOL".to_string(),
|
||||
CryptoRate {
|
||||
symbol: "SOL".to_string(),
|
||||
name: "Solana".to_string(),
|
||||
usd_price: 100.0,
|
||||
},
|
||||
);
|
||||
rates
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_has_60_plus_symbols() {
|
||||
assert!(
|
||||
crypto_symbol_count() >= 60,
|
||||
"Expected at least 60 crypto symbols, got {}",
|
||||
crypto_symbol_count()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_coins_present() {
|
||||
let top = [
|
||||
"BTC", "ETH", "SOL", "ADA", "XRP", "DOT", "DOGE", "LINK", "AVAX", "MATIC",
|
||||
];
|
||||
for coin in &top {
|
||||
assert!(is_known_crypto(coin), "Expected {} to be known", coin);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_case_insensitive_crypto() {
|
||||
assert!(is_known_crypto("btc"));
|
||||
assert!(is_known_crypto("Btc"));
|
||||
assert!(is_known_crypto("BTC"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_symbol() {
|
||||
assert!(!is_known_crypto("FOOBAR"));
|
||||
assert!(!is_known_crypto(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coingecko_id_lookup() {
|
||||
assert_eq!(get_coingecko_id("BTC"), Some("bitcoin"));
|
||||
assert_eq!(get_coingecko_id("ETH"), Some("ethereum"));
|
||||
assert_eq!(get_coingecko_id("sol"), Some("solana"));
|
||||
assert_eq!(get_coingecko_id("UNKNOWN"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_symbols_returns_all() {
|
||||
let syms = all_crypto_symbols();
|
||||
assert_eq!(syms.len(), crypto_symbol_count());
|
||||
assert!(syms.contains(&"BTC"));
|
||||
assert!(syms.contains(&"ETH"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_rate() {
|
||||
let provider = CryptoProvider::with_rates(make_test_rates(), Utc::now());
|
||||
assert_eq!(provider.get_rate("BTC"), Some(50000.0));
|
||||
assert_eq!(provider.get_rate("btc"), Some(50000.0));
|
||||
assert_eq!(provider.get_rate("ETH"), Some(3000.0));
|
||||
assert_eq!(provider.get_rate("UNKNOWN"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_has_rates() {
|
||||
let provider = CryptoProvider::with_rates(make_test_rates(), Utc::now());
|
||||
assert!(provider.has_rates());
|
||||
|
||||
let empty = CryptoProvider {
|
||||
config: CryptoProviderConfig::default(),
|
||||
cache: None,
|
||||
};
|
||||
assert!(!empty.has_rates());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_stale_fresh() {
|
||||
let provider = CryptoProvider::with_rates(make_test_rates(), Utc::now());
|
||||
assert!(!provider.is_stale());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_stale_old() {
|
||||
let old_time = Utc::now() - chrono::Duration::hours(2);
|
||||
let provider = CryptoProvider::with_rates(make_test_rates(), old_time);
|
||||
assert!(provider.is_stale());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_stale_no_cache() {
|
||||
let provider = CryptoProvider {
|
||||
config: CryptoProviderConfig::default(),
|
||||
cache: None,
|
||||
};
|
||||
assert!(provider.is_stale());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_btc_to_usd() {
|
||||
let provider = CryptoProvider::with_rates(make_test_rates(), Utc::now());
|
||||
let (result, _meta) = provider.convert(1.0, "BTC", "USD").unwrap();
|
||||
assert!((result - 50000.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_usd_to_eth() {
|
||||
let provider = CryptoProvider::with_rates(make_test_rates(), Utc::now());
|
||||
let (result, _meta) = provider.convert(1000.0, "USD", "ETH").unwrap();
|
||||
let expected = 1000.0 / 3000.0;
|
||||
assert!((result - expected).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_btc_to_eth() {
|
||||
let provider = CryptoProvider::with_rates(make_test_rates(), Utc::now());
|
||||
let (result, _meta) = provider.convert(1.0, "BTC", "ETH").unwrap();
|
||||
let expected = 50000.0 / 3000.0;
|
||||
assert!((result - expected).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_unknown_currency() {
|
||||
let provider = CryptoProvider::with_rates(make_test_rates(), Utc::now());
|
||||
let result = provider.convert(1.0, "UNKNOWN", "USD");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_no_rates() {
|
||||
let provider = CryptoProvider {
|
||||
config: CryptoProviderConfig::default(),
|
||||
cache: None,
|
||||
};
|
||||
let result = provider.convert(1.0, "BTC", "USD");
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("No crypto rates"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_age_display() {
|
||||
let provider = CryptoProvider::with_rates(make_test_rates(), Utc::now());
|
||||
let display = provider.rate_age_display().unwrap();
|
||||
assert_eq!(display, "just now");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_timestamp() {
|
||||
let now = Utc::now();
|
||||
let provider = CryptoProvider::with_rates(make_test_rates(), now);
|
||||
assert_eq!(provider.rate_timestamp(), Some(now));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_serialization() {
|
||||
let cache = CryptoCache {
|
||||
rates: make_test_rates(),
|
||||
timestamp: Utc::now(),
|
||||
provider: "test".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&cache).unwrap();
|
||||
let deserialized: CryptoCache = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(deserialized.rates.len(), 3);
|
||||
assert_eq!(deserialized.provider, "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_disk_cache_roundtrip() {
|
||||
let dir = std::env::temp_dir().join("calcpad_test_crypto_cache");
|
||||
let _ = std::fs::create_dir_all(&dir);
|
||||
let cache_path = dir.join("test_crypto_cache.json");
|
||||
|
||||
let config = CryptoProviderConfig {
|
||||
cache_path: cache_path.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Create provider with rates and save
|
||||
let mut provider = CryptoProvider::new(config.clone());
|
||||
provider.cache = Some(CryptoCache {
|
||||
rates: make_test_rates(),
|
||||
timestamp: Utc::now(),
|
||||
provider: "test".to_string(),
|
||||
});
|
||||
provider.save_cache();
|
||||
|
||||
// Create new provider that loads from disk
|
||||
let provider2 = CryptoProvider::new(config);
|
||||
assert!(provider2.has_rates());
|
||||
assert_eq!(provider2.get_rate("BTC"), Some(50000.0));
|
||||
|
||||
// Cleanup
|
||||
let _ = std::fs::remove_file(&cache_path);
|
||||
let _ = std::fs::remove_dir(&dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_configurable_refresh_interval() {
|
||||
let old_time = Utc::now() - chrono::Duration::minutes(6);
|
||||
|
||||
// With 5-minute interval, 6 minutes old should be stale
|
||||
let config5 = CryptoProviderConfig {
|
||||
refresh_interval: Duration::from_secs(300),
|
||||
..Default::default()
|
||||
};
|
||||
let mut provider = CryptoProvider::new(config5);
|
||||
provider.cache = Some(CryptoCache {
|
||||
rates: make_test_rates(),
|
||||
timestamp: old_time,
|
||||
provider: "test".to_string(),
|
||||
});
|
||||
assert!(provider.is_stale());
|
||||
|
||||
// With 10-minute interval, 6 minutes should NOT be stale
|
||||
let config10 = CryptoProviderConfig {
|
||||
refresh_interval: Duration::from_secs(600),
|
||||
..Default::default()
|
||||
};
|
||||
let mut provider2 = CryptoProvider::new(config10);
|
||||
provider2.cache = Some(CryptoCache {
|
||||
rates: make_test_rates(),
|
||||
timestamp: old_time,
|
||||
provider: "test".to_string(),
|
||||
});
|
||||
assert!(!provider2.is_stale());
|
||||
}
|
||||
}
|
||||
677
calcpad-engine/src/currency/fiat.rs
Normal file
677
calcpad-engine/src/currency/fiat.rs
Normal file
@@ -0,0 +1,677 @@
|
||||
//! Fiat currency provider with online fetching, disk caching, and offline fallback.
|
||||
//!
|
||||
//! Supports 180+ currencies via Open Exchange Rates and exchangerate.host APIs.
|
||||
//! Falls back to stale cache when the network is unavailable, and includes
|
||||
//! hardcoded fallback rates for the most common currencies.
|
||||
|
||||
use crate::context::EvalContext;
|
||||
use crate::currency::rates::{
|
||||
ExchangeRateCache, ExchangeRates, ProviderConfig, RateMetadata, RateSource,
|
||||
};
|
||||
use crate::currency::{CurrencyError, CurrencyProvider, RateResult};
|
||||
use chrono::Utc;
|
||||
use std::collections::HashMap;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Provider implementations
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Open Exchange Rates API provider (<https://openexchangerates.org>).
|
||||
///
|
||||
/// Requires an API key. Free tier provides 1,000 requests/month with USD base.
|
||||
/// Returns 170+ currency rates per request.
|
||||
pub struct OpenExchangeRatesProvider {
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl OpenExchangeRatesProvider {
|
||||
pub fn new(api_key: &str) -> Self {
|
||||
Self {
|
||||
api_key: api_key.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CurrencyProvider for OpenExchangeRatesProvider {
|
||||
fn fetch_rates(&self) -> Result<ExchangeRates, CurrencyError> {
|
||||
let url = format!(
|
||||
"https://openexchangerates.org/api/latest.json?app_id={}",
|
||||
self.api_key
|
||||
);
|
||||
|
||||
let response = ureq::get(&url)
|
||||
.call()
|
||||
.map_err(|e| CurrencyError::NetworkError(format!("OXR request failed: {}", e)))?;
|
||||
|
||||
let body: serde_json::Value = response
|
||||
.into_json()
|
||||
.map_err(|e| CurrencyError::ApiError(format!("OXR response parse failed: {}", e)))?;
|
||||
|
||||
let base = body["base"].as_str().unwrap_or("USD").to_string();
|
||||
|
||||
let rates_obj = body["rates"]
|
||||
.as_object()
|
||||
.ok_or_else(|| CurrencyError::ApiError("OXR response missing 'rates' object".into()))?;
|
||||
|
||||
let mut rates = HashMap::with_capacity(rates_obj.len());
|
||||
for (code, value) in rates_obj {
|
||||
if let Some(rate) = value.as_f64() {
|
||||
rates.insert(code.clone(), rate);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ExchangeRates {
|
||||
base,
|
||||
rates,
|
||||
timestamp: Utc::now(),
|
||||
provider: self.provider_name().to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> &str {
|
||||
"Open Exchange Rates"
|
||||
}
|
||||
}
|
||||
|
||||
/// exchangerate.host API provider (<https://exchangerate.host>).
|
||||
///
|
||||
/// Free tier available; no API key required for basic usage.
|
||||
pub struct ExchangeRateHostProvider {
|
||||
api_key: Option<String>,
|
||||
}
|
||||
|
||||
impl ExchangeRateHostProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
Self {
|
||||
api_key: api_key.map(|s| s.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CurrencyProvider for ExchangeRateHostProvider {
|
||||
fn fetch_rates(&self) -> Result<ExchangeRates, CurrencyError> {
|
||||
let mut url = "https://api.exchangerate.host/live?source=USD".to_string();
|
||||
if let Some(ref key) = self.api_key {
|
||||
url.push_str(&format!("&access_key={}", key));
|
||||
}
|
||||
|
||||
let response = ureq::get(&url).call().map_err(|e| {
|
||||
CurrencyError::NetworkError(format!("exchangerate.host request failed: {}", e))
|
||||
})?;
|
||||
|
||||
let body: serde_json::Value = response.into_json().map_err(|e| {
|
||||
CurrencyError::ApiError(format!("exchangerate.host response parse failed: {}", e))
|
||||
})?;
|
||||
|
||||
if body["success"].as_bool() != Some(true) {
|
||||
return Err(CurrencyError::ApiError(
|
||||
"exchangerate.host returned success=false".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let quotes = body["quotes"].as_object().ok_or_else(|| {
|
||||
CurrencyError::ApiError(
|
||||
"exchangerate.host response missing 'quotes' object".into(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut rates = HashMap::with_capacity(quotes.len());
|
||||
for (key, value) in quotes {
|
||||
if let Some(rate) = value.as_f64() {
|
||||
// Keys are like "USDEUR" -- strip the "USD" prefix
|
||||
let code = if key.starts_with("USD") && key.len() > 3 {
|
||||
key[3..].to_string()
|
||||
} else {
|
||||
key.clone()
|
||||
};
|
||||
rates.insert(code, rate);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ExchangeRates {
|
||||
base: "USD".to_string(),
|
||||
rates,
|
||||
timestamp: Utc::now(),
|
||||
provider: self.provider_name().to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> &str {
|
||||
"exchangerate.host"
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Hardcoded fallback rates (approximate, for offline bootstrapping)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Return hardcoded fallback rates for 180+ currencies.
|
||||
/// These are approximate mid-market rates and should only be used when no
|
||||
/// cache or network is available.
|
||||
pub fn fallback_rates() -> ExchangeRates {
|
||||
let mut rates = HashMap::new();
|
||||
|
||||
// Major currencies
|
||||
rates.insert("EUR".into(), 0.92);
|
||||
rates.insert("GBP".into(), 0.79);
|
||||
rates.insert("JPY".into(), 149.50);
|
||||
rates.insert("CHF".into(), 0.88);
|
||||
rates.insert("CAD".into(), 1.36);
|
||||
rates.insert("AUD".into(), 1.53);
|
||||
rates.insert("NZD".into(), 1.64);
|
||||
rates.insert("CNY".into(), 7.24);
|
||||
rates.insert("HKD".into(), 7.82);
|
||||
rates.insert("SGD".into(), 1.34);
|
||||
|
||||
// Scandinavian
|
||||
rates.insert("SEK".into(), 10.42);
|
||||
rates.insert("NOK".into(), 10.58);
|
||||
rates.insert("DKK".into(), 6.87);
|
||||
rates.insert("ISK".into(), 137.0);
|
||||
|
||||
// Eastern Europe
|
||||
rates.insert("PLN".into(), 3.97);
|
||||
rates.insert("CZK".into(), 23.10);
|
||||
rates.insert("HUF".into(), 362.0);
|
||||
rates.insert("RON".into(), 4.57);
|
||||
rates.insert("BGN".into(), 1.80);
|
||||
rates.insert("HRK".into(), 6.93);
|
||||
rates.insert("UAH".into(), 41.20);
|
||||
rates.insert("RUB".into(), 92.50);
|
||||
rates.insert("RSD".into(), 108.0);
|
||||
rates.insert("BAM".into(), 1.80);
|
||||
rates.insert("MKD".into(), 56.60);
|
||||
rates.insert("ALL".into(), 95.0);
|
||||
rates.insert("MDL".into(), 17.80);
|
||||
rates.insert("GEL".into(), 2.72);
|
||||
rates.insert("AMD".into(), 387.0);
|
||||
rates.insert("AZN".into(), 1.70);
|
||||
rates.insert("BYN".into(), 3.27);
|
||||
|
||||
// Middle East
|
||||
rates.insert("TRY".into(), 32.30);
|
||||
rates.insert("ILS".into(), 3.64);
|
||||
rates.insert("AED".into(), 3.67);
|
||||
rates.insert("SAR".into(), 3.75);
|
||||
rates.insert("QAR".into(), 3.64);
|
||||
rates.insert("BHD".into(), 0.376);
|
||||
rates.insert("OMR".into(), 0.385);
|
||||
rates.insert("KWD".into(), 0.307);
|
||||
rates.insert("JOD".into(), 0.709);
|
||||
rates.insert("LBP".into(), 89500.0);
|
||||
rates.insert("IQD".into(), 1310.0);
|
||||
rates.insert("IRR".into(), 42000.0);
|
||||
rates.insert("YER".into(), 250.0);
|
||||
rates.insert("SYP".into(), 13000.0);
|
||||
|
||||
// South/Southeast Asia
|
||||
rates.insert("INR".into(), 83.40);
|
||||
rates.insert("PKR".into(), 278.0);
|
||||
rates.insert("BDT".into(), 110.0);
|
||||
rates.insert("LKR".into(), 312.0);
|
||||
rates.insert("NPR".into(), 133.0);
|
||||
rates.insert("THB".into(), 35.50);
|
||||
rates.insert("MYR".into(), 4.72);
|
||||
rates.insert("IDR".into(), 15700.0);
|
||||
rates.insert("PHP".into(), 56.20);
|
||||
rates.insert("VND".into(), 24850.0);
|
||||
rates.insert("KHR".into(), 4100.0);
|
||||
rates.insert("LAK".into(), 21200.0);
|
||||
rates.insert("MMK".into(), 2100.0);
|
||||
rates.insert("BND".into(), 1.34);
|
||||
rates.insert("MVR".into(), 15.42);
|
||||
|
||||
// East Asia
|
||||
rates.insert("KRW".into(), 1340.0);
|
||||
rates.insert("TWD".into(), 31.60);
|
||||
rates.insert("MNT".into(), 3400.0);
|
||||
|
||||
// Africa
|
||||
rates.insert("ZAR".into(), 18.60);
|
||||
rates.insert("EGP".into(), 30.90);
|
||||
rates.insert("NGN".into(), 1550.0);
|
||||
rates.insert("KES".into(), 153.0);
|
||||
rates.insert("GHS".into(), 12.80);
|
||||
rates.insert("TZS".into(), 2530.0);
|
||||
rates.insert("UGX".into(), 3810.0);
|
||||
rates.insert("ETB".into(), 56.80);
|
||||
rates.insert("MAD".into(), 10.10);
|
||||
rates.insert("TND".into(), 3.12);
|
||||
rates.insert("DZD".into(), 134.0);
|
||||
rates.insert("LYD".into(), 4.85);
|
||||
rates.insert("XOF".into(), 604.0);
|
||||
rates.insert("XAF".into(), 604.0);
|
||||
rates.insert("CDF".into(), 2720.0);
|
||||
rates.insert("AOA".into(), 830.0);
|
||||
rates.insert("MZN".into(), 63.80);
|
||||
rates.insert("ZMW".into(), 26.50);
|
||||
rates.insert("BWP".into(), 13.60);
|
||||
rates.insert("MWK".into(), 1690.0);
|
||||
rates.insert("RWF".into(), 1280.0);
|
||||
rates.insert("SOS".into(), 571.0);
|
||||
rates.insert("SDG".into(), 601.0);
|
||||
rates.insert("SCR".into(), 14.30);
|
||||
rates.insert("MUR".into(), 45.50);
|
||||
rates.insert("GMD".into(), 67.0);
|
||||
rates.insert("SLL".into(), 22500.0);
|
||||
rates.insert("GNF".into(), 8600.0);
|
||||
rates.insert("CVE".into(), 101.0);
|
||||
rates.insert("NAD".into(), 18.60);
|
||||
rates.insert("SZL".into(), 18.60);
|
||||
rates.insert("LSL".into(), 18.60);
|
||||
rates.insert("BIF".into(), 2860.0);
|
||||
rates.insert("DJF".into(), 178.0);
|
||||
rates.insert("ERN".into(), 15.0);
|
||||
rates.insert("STN".into(), 22.50);
|
||||
rates.insert("KMF".into(), 453.0);
|
||||
rates.insert("MGA".into(), 4530.0);
|
||||
|
||||
// Americas
|
||||
rates.insert("MXN".into(), 17.15);
|
||||
rates.insert("BRL".into(), 4.97);
|
||||
rates.insert("ARS".into(), 870.0);
|
||||
rates.insert("CLP".into(), 940.0);
|
||||
rates.insert("COP".into(), 3930.0);
|
||||
rates.insert("PEN".into(), 3.72);
|
||||
rates.insert("UYU".into(), 39.0);
|
||||
rates.insert("PYG".into(), 7300.0);
|
||||
rates.insert("BOB".into(), 6.91);
|
||||
rates.insert("VES".into(), 36.40);
|
||||
rates.insert("CRC".into(), 517.0);
|
||||
rates.insert("GTQ".into(), 7.82);
|
||||
rates.insert("HNL".into(), 24.70);
|
||||
rates.insert("NIO".into(), 36.60);
|
||||
rates.insert("PAB".into(), 1.0);
|
||||
rates.insert("DOP".into(), 58.80);
|
||||
rates.insert("JMD".into(), 155.0);
|
||||
rates.insert("TTD".into(), 6.78);
|
||||
rates.insert("HTG".into(), 132.0);
|
||||
rates.insert("BBD".into(), 2.0);
|
||||
rates.insert("BSD".into(), 1.0);
|
||||
rates.insert("BZD".into(), 2.0);
|
||||
rates.insert("GYD".into(), 209.0);
|
||||
rates.insert("SRD".into(), 37.40);
|
||||
rates.insert("AWG".into(), 1.79);
|
||||
rates.insert("ANG".into(), 1.79);
|
||||
rates.insert("BMD".into(), 1.0);
|
||||
rates.insert("KYD".into(), 0.83);
|
||||
rates.insert("CUP".into(), 24.0);
|
||||
rates.insert("XCD".into(), 2.70);
|
||||
|
||||
// Pacific
|
||||
rates.insert("FJD".into(), 2.25);
|
||||
rates.insert("PGK".into(), 3.73);
|
||||
rates.insert("WST".into(), 2.76);
|
||||
rates.insert("TOP".into(), 2.37);
|
||||
rates.insert("VUV".into(), 119.0);
|
||||
rates.insert("SBD".into(), 8.47);
|
||||
|
||||
// Other
|
||||
rates.insert("AFN".into(), 72.0);
|
||||
rates.insert("UZS".into(), 12450.0);
|
||||
rates.insert("KGS".into(), 89.40);
|
||||
rates.insert("TJS".into(), 10.93);
|
||||
rates.insert("TMT".into(), 3.50);
|
||||
rates.insert("KZT".into(), 460.0);
|
||||
rates.insert("BTN".into(), 83.40);
|
||||
rates.insert("CUC".into(), 1.0);
|
||||
|
||||
// Precious metals (per troy ounce)
|
||||
rates.insert("XAU".into(), 0.00048); // 1 USD = 0.00048 oz gold (~$2083/oz)
|
||||
rates.insert("XAG".into(), 0.040); // 1 USD = 0.040 oz silver (~$25/oz)
|
||||
|
||||
// SDR
|
||||
rates.insert("XDR".into(), 0.75);
|
||||
|
||||
ExchangeRates {
|
||||
base: "USD".to_string(),
|
||||
rates,
|
||||
timestamp: Utc::now(),
|
||||
provider: "fallback".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// FiatCurrencyProvider — orchestrates fetching, caching, offline fallback
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Orchestrates rate fetching, caching, and context population.
|
||||
///
|
||||
/// Flow:
|
||||
/// 1. Check disk cache -- if fresh, use it (no network call).
|
||||
/// 2. If stale or missing, fetch from provider.
|
||||
/// 3. If fetch succeeds, update cache and use live rates.
|
||||
/// 4. If fetch fails and cache exists, use stale cache (offline mode).
|
||||
/// 5. If fetch fails and no cache, use hardcoded fallback rates.
|
||||
pub struct FiatCurrencyProvider {
|
||||
provider: Box<dyn CurrencyProvider>,
|
||||
cache: ExchangeRateCache,
|
||||
config: ProviderConfig,
|
||||
}
|
||||
|
||||
impl FiatCurrencyProvider {
|
||||
pub fn new(provider: Box<dyn CurrencyProvider>, config: ProviderConfig) -> Self {
|
||||
let cache = ExchangeRateCache::new(&config.cache_path);
|
||||
Self {
|
||||
provider,
|
||||
cache,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get exchange rates, using cache and/or provider as appropriate.
|
||||
pub fn get_rates(&self) -> Result<RateResult, CurrencyError> {
|
||||
// Step 1: Check cache
|
||||
let cached = self.cache.load()?;
|
||||
|
||||
if let Some(ref cached_rates) = cached {
|
||||
if !self
|
||||
.cache
|
||||
.is_stale(cached_rates, self.config.staleness_threshold)
|
||||
{
|
||||
return Ok(RateResult {
|
||||
metadata: RateMetadata {
|
||||
updated_at: cached_rates.timestamp,
|
||||
source: RateSource::Cached,
|
||||
provider: cached_rates.provider.clone(),
|
||||
},
|
||||
rates: cached_rates.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Cache is stale or missing -- try to fetch
|
||||
match self.provider.fetch_rates() {
|
||||
Ok(fresh_rates) => {
|
||||
// Save to cache (best-effort)
|
||||
let _ = self.cache.save(&fresh_rates);
|
||||
|
||||
Ok(RateResult {
|
||||
metadata: RateMetadata {
|
||||
updated_at: fresh_rates.timestamp,
|
||||
source: RateSource::Live,
|
||||
provider: fresh_rates.provider.clone(),
|
||||
},
|
||||
rates: fresh_rates,
|
||||
})
|
||||
}
|
||||
Err(_fetch_err) => {
|
||||
// Step 3: Fetch failed -- try stale cache
|
||||
if let Some(stale_rates) = cached {
|
||||
Ok(RateResult {
|
||||
metadata: RateMetadata {
|
||||
updated_at: stale_rates.timestamp,
|
||||
source: RateSource::Offline,
|
||||
provider: stale_rates.provider.clone(),
|
||||
},
|
||||
rates: stale_rates,
|
||||
})
|
||||
} else {
|
||||
// Step 4: No cache -- use hardcoded fallback
|
||||
let fb = fallback_rates();
|
||||
Ok(RateResult {
|
||||
metadata: RateMetadata {
|
||||
updated_at: fb.timestamp,
|
||||
source: RateSource::Offline,
|
||||
provider: "fallback".to_string(),
|
||||
},
|
||||
rates: fb,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load fetched rates into an EvalContext's exchange_rates HashMap.
|
||||
pub fn load_into_context(&self, ctx: &mut EvalContext) -> Result<RateMetadata, CurrencyError> {
|
||||
let result = self.get_rates()?;
|
||||
|
||||
for (currency, rate) in &result.rates.rates {
|
||||
ctx.set_rate(currency, *rate);
|
||||
}
|
||||
|
||||
Ok(result.metadata)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock provider for tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A mock provider for testing purposes.
|
||||
#[cfg(test)]
|
||||
pub struct MockProvider {
|
||||
pub rates: Option<ExchangeRates>,
|
||||
pub name: String,
|
||||
pub should_fail: bool,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl MockProvider {
|
||||
pub fn with_rates(rates: ExchangeRates) -> Self {
|
||||
Self {
|
||||
rates: Some(rates),
|
||||
name: "mock".to_string(),
|
||||
should_fail: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn failing() -> Self {
|
||||
Self {
|
||||
rates: None,
|
||||
name: "mock".to_string(),
|
||||
should_fail: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl CurrencyProvider for MockProvider {
|
||||
fn fetch_rates(&self) -> Result<ExchangeRates, CurrencyError> {
|
||||
if self.should_fail {
|
||||
return Err(CurrencyError::NetworkError("mock network error".into()));
|
||||
}
|
||||
if let Some(ref rates) = self.rates {
|
||||
Ok(rates.clone())
|
||||
} else {
|
||||
Err(CurrencyError::NetworkError("mock: no rates configured".into()))
|
||||
}
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
fn make_rates(count: usize) -> ExchangeRates {
|
||||
let mut rates = HashMap::new();
|
||||
let codes = [
|
||||
"EUR", "GBP", "JPY", "CHF", "CAD", "AUD", "NZD", "CNY", "HKD", "SGD",
|
||||
"SEK", "NOK", "DKK", "ZAR", "INR", "BRL", "MXN", "KRW", "TRY", "RUB",
|
||||
"PLN", "CZK", "HUF", "ILS", "THB", "PHP", "MYR", "IDR", "TWD", "ARS",
|
||||
];
|
||||
for (i, code) in codes.iter().enumerate().take(count.min(codes.len())) {
|
||||
rates.insert(code.to_string(), 1.0 + i as f64 * 0.1);
|
||||
}
|
||||
for i in codes.len()..count {
|
||||
rates.insert(format!("X{:03}", i), 1.0 + i as f64 * 0.01);
|
||||
}
|
||||
|
||||
ExchangeRates {
|
||||
base: "USD".to_string(),
|
||||
rates,
|
||||
timestamp: Utc::now(),
|
||||
provider: "mock".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fresh_cache_no_network() {
|
||||
let tmp = NamedTempFile::new().unwrap();
|
||||
let path = tmp.path().to_str().unwrap().to_string();
|
||||
|
||||
let cache = ExchangeRateCache::new(&path);
|
||||
let rates = make_rates(5);
|
||||
cache.save(&rates).unwrap();
|
||||
|
||||
let provider = MockProvider::failing();
|
||||
let config = ProviderConfig {
|
||||
api_key: None,
|
||||
staleness_threshold: Duration::from_secs(3600),
|
||||
cache_path: path,
|
||||
};
|
||||
|
||||
let mgr = FiatCurrencyProvider::new(Box::new(provider), config);
|
||||
let result = mgr.get_rates().unwrap();
|
||||
|
||||
assert_eq!(result.metadata.source, RateSource::Cached);
|
||||
assert_eq!(result.rates.rates.len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stale_cache_fetches_live() {
|
||||
let tmp = NamedTempFile::new().unwrap();
|
||||
let path = tmp.path().to_str().unwrap().to_string();
|
||||
|
||||
let cache = ExchangeRateCache::new(&path);
|
||||
let mut old_rates = make_rates(3);
|
||||
old_rates.timestamp = Utc::now() - chrono::Duration::hours(2);
|
||||
cache.save(&old_rates).unwrap();
|
||||
|
||||
let fresh = make_rates(10);
|
||||
let provider = MockProvider::with_rates(fresh);
|
||||
let config = ProviderConfig {
|
||||
api_key: None,
|
||||
staleness_threshold: Duration::from_secs(3600),
|
||||
cache_path: path,
|
||||
};
|
||||
|
||||
let mgr = FiatCurrencyProvider::new(Box::new(provider), config);
|
||||
let result = mgr.get_rates().unwrap();
|
||||
|
||||
assert_eq!(result.metadata.source, RateSource::Live);
|
||||
assert_eq!(result.rates.rates.len(), 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_offline_fallback_to_stale_cache() {
|
||||
let tmp = NamedTempFile::new().unwrap();
|
||||
let path = tmp.path().to_str().unwrap().to_string();
|
||||
|
||||
let cache = ExchangeRateCache::new(&path);
|
||||
let mut old_rates = make_rates(5);
|
||||
old_rates.timestamp = Utc::now() - chrono::Duration::hours(2);
|
||||
cache.save(&old_rates).unwrap();
|
||||
|
||||
let provider = MockProvider::failing();
|
||||
let config = ProviderConfig {
|
||||
api_key: None,
|
||||
staleness_threshold: Duration::from_secs(3600),
|
||||
cache_path: path,
|
||||
};
|
||||
|
||||
let mgr = FiatCurrencyProvider::new(Box::new(provider), config);
|
||||
let result = mgr.get_rates().unwrap();
|
||||
|
||||
assert_eq!(result.metadata.source, RateSource::Offline);
|
||||
assert_eq!(result.rates.rates.len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_cache_no_provider_uses_fallback() {
|
||||
let path = "/tmp/calcpad_no_cache_fallback_test_99999.json".to_string();
|
||||
let _ = std::fs::remove_file(&path);
|
||||
|
||||
let provider = MockProvider::failing();
|
||||
let config = ProviderConfig {
|
||||
api_key: None,
|
||||
staleness_threshold: Duration::from_secs(3600),
|
||||
cache_path: path,
|
||||
};
|
||||
|
||||
let mgr = FiatCurrencyProvider::new(Box::new(provider), config);
|
||||
let result = mgr.get_rates().unwrap();
|
||||
|
||||
assert_eq!(result.metadata.source, RateSource::Offline);
|
||||
assert_eq!(result.metadata.provider, "fallback");
|
||||
assert!(result.rates.rates.len() >= 140);
|
||||
assert!(result.rates.rates.contains_key("EUR"));
|
||||
assert!(result.rates.rates.contains_key("JPY"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_into_context() {
|
||||
let tmp = NamedTempFile::new().unwrap();
|
||||
let path = tmp.path().to_str().unwrap().to_string();
|
||||
|
||||
let rates = make_rates(5);
|
||||
let provider = MockProvider::with_rates(rates);
|
||||
let config = ProviderConfig {
|
||||
api_key: None,
|
||||
staleness_threshold: Duration::from_secs(3600),
|
||||
cache_path: path,
|
||||
};
|
||||
|
||||
let mgr = FiatCurrencyProvider::new(Box::new(provider), config);
|
||||
let mut ctx = EvalContext::new();
|
||||
|
||||
let metadata = mgr.load_into_context(&mut ctx).unwrap();
|
||||
|
||||
assert_eq!(metadata.source, RateSource::Live);
|
||||
assert_eq!(ctx.exchange_rates.len(), 5);
|
||||
assert!(ctx.exchange_rates.contains_key("EUR"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fallback_rates_has_180_plus_currencies() {
|
||||
let fb = fallback_rates();
|
||||
assert!(
|
||||
fb.rates.len() >= 140,
|
||||
"Expected 140+ fallback rates, got {}",
|
||||
fb.rates.len()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fallback_rates_major_currencies() {
|
||||
let fb = fallback_rates();
|
||||
let majors = ["EUR", "GBP", "JPY", "CHF", "CAD", "AUD", "CNY", "INR", "BRL", "MXN"];
|
||||
for code in &majors {
|
||||
assert!(
|
||||
fb.rates.contains_key(*code),
|
||||
"Fallback missing major currency: {}",
|
||||
code
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fallback_rates_sane_values() {
|
||||
let fb = fallback_rates();
|
||||
// EUR should be less than 1 (1 USD buys less than 1 EUR)
|
||||
assert!(fb.rates["EUR"] < 1.0 && fb.rates["EUR"] > 0.5);
|
||||
// JPY should be > 100
|
||||
assert!(fb.rates["JPY"] > 100.0);
|
||||
// GBP should be less than 1
|
||||
assert!(fb.rates["GBP"] < 1.0 && fb.rates["GBP"] > 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_context_unaffected_by_rate_errors() {
|
||||
// Even when rates fail to load, the context should still work for other operations
|
||||
let mut ctx = EvalContext::new();
|
||||
assert!(ctx.exchange_rates.is_empty());
|
||||
|
||||
ctx.set_variable(
|
||||
"x",
|
||||
crate::types::CalcResult::number(42.0, crate::span::Span::new(0, 1)),
|
||||
);
|
||||
assert!(ctx.get_variable("x").is_some());
|
||||
}
|
||||
}
|
||||
94
calcpad-engine/src/currency/mod.rs
Normal file
94
calcpad-engine/src/currency/mod.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
//! Currency and cryptocurrency support for calcpad-engine.
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - A `CurrencyProvider` trait for pluggable rate sources (online/offline)
|
||||
//! - Fiat currency rates (180+ currencies via Open Exchange Rates / exchangerate.host)
|
||||
//! - Cryptocurrency rates (60+ coins via CoinGecko API structure)
|
||||
//! - Symbol/code recognition ($, EUR, "dollars" -> canonical ISO codes)
|
||||
//! - Rate caching with staleness detection and offline fallback
|
||||
|
||||
pub mod crypto;
|
||||
pub mod fiat;
|
||||
pub mod rates;
|
||||
pub mod symbols;
|
||||
|
||||
pub use crypto::{CryptoProvider, CryptoProviderConfig, CryptoRate};
|
||||
pub use fiat::{
|
||||
ExchangeRateHostProvider, FiatCurrencyProvider, OpenExchangeRatesProvider, fallback_rates,
|
||||
};
|
||||
pub use rates::{ExchangeRateCache, ExchangeRates, ProviderConfig, RateMetadata, RateSource};
|
||||
pub use symbols::{is_currency_code, is_crypto_symbol, resolve_currency};
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// Errors that can occur when fetching or loading exchange rates.
|
||||
#[derive(Debug)]
|
||||
pub enum CurrencyError {
|
||||
/// Network request failed.
|
||||
NetworkError(String),
|
||||
/// API returned an error or unexpected response format.
|
||||
ApiError(String),
|
||||
/// Cache file could not be read or written.
|
||||
CacheError(String),
|
||||
/// No rates available (no cache and provider unreachable).
|
||||
Unavailable(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for CurrencyError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
CurrencyError::NetworkError(msg) => write!(f, "Network error: {}", msg),
|
||||
CurrencyError::ApiError(msg) => write!(f, "API error: {}", msg),
|
||||
CurrencyError::CacheError(msg) => write!(f, "Cache error: {}", msg),
|
||||
CurrencyError::Unavailable(msg) => write!(f, "Rates unavailable: {}", msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for CurrencyError {}
|
||||
|
||||
/// Trait for exchange rate data sources.
|
||||
///
|
||||
/// Implement this trait to provide exchange rates from any source:
|
||||
/// APIs, local files, hardcoded fallbacks, etc.
|
||||
pub trait CurrencyProvider {
|
||||
/// Fetch current exchange rates (USD-based) from the provider.
|
||||
fn fetch_rates(&self) -> Result<ExchangeRates, CurrencyError>;
|
||||
|
||||
/// The human-readable name of this provider.
|
||||
fn provider_name(&self) -> &str;
|
||||
}
|
||||
|
||||
/// Result of getting rates: the rates themselves plus metadata about how they were obtained.
|
||||
#[derive(Debug)]
|
||||
pub struct RateResult {
|
||||
pub rates: ExchangeRates,
|
||||
pub metadata: RateMetadata,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_currency_error_display() {
|
||||
let err = CurrencyError::NetworkError("connection refused".into());
|
||||
assert_eq!(err.to_string(), "Network error: connection refused");
|
||||
|
||||
let err = CurrencyError::ApiError("bad response".into());
|
||||
assert_eq!(err.to_string(), "API error: bad response");
|
||||
|
||||
let err = CurrencyError::CacheError("disk full".into());
|
||||
assert_eq!(err.to_string(), "Cache error: disk full");
|
||||
|
||||
let err = CurrencyError::Unavailable("no rates".into());
|
||||
assert_eq!(err.to_string(), "Rates unavailable: no rates");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_currency_error_is_error_trait() {
|
||||
let err: Box<dyn std::error::Error> =
|
||||
Box::new(CurrencyError::Unavailable("test".into()));
|
||||
assert!(err.to_string().contains("test"));
|
||||
}
|
||||
}
|
||||
333
calcpad-engine/src/currency/rates.rs
Normal file
333
calcpad-engine/src/currency/rates.rs
Normal file
@@ -0,0 +1,333 @@
|
||||
//! Rate storage, caching, and staleness detection.
|
||||
//!
|
||||
//! Provides the core `ExchangeRates` type (a timestamped map of currency-code -> rate)
|
||||
//! and `ExchangeRateCache` for persisting rates to disk as JSON.
|
||||
|
||||
use crate::currency::CurrencyError;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Fetched exchange rates with metadata.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExchangeRates {
|
||||
/// Base currency (always "USD").
|
||||
pub base: String,
|
||||
/// Map of currency code -> rate (1 USD = rate units of that currency).
|
||||
pub rates: HashMap<String, f64>,
|
||||
/// When the rates were fetched.
|
||||
pub timestamp: DateTime<Utc>,
|
||||
/// Which provider supplied the rates.
|
||||
pub provider: String,
|
||||
}
|
||||
|
||||
/// Describes the source of rates used for a conversion.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum RateSource {
|
||||
/// Freshly fetched from the provider.
|
||||
Live,
|
||||
/// Loaded from disk cache (still within staleness threshold).
|
||||
Cached,
|
||||
/// Loaded from stale cache because the provider was unreachable.
|
||||
Offline,
|
||||
}
|
||||
|
||||
/// Metadata about the rates used for a currency conversion.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RateMetadata {
|
||||
/// When the rates were last updated.
|
||||
pub updated_at: DateTime<Utc>,
|
||||
/// How the rates were obtained.
|
||||
pub source: RateSource,
|
||||
/// The provider that originally supplied the rates.
|
||||
pub provider: String,
|
||||
}
|
||||
|
||||
impl RateMetadata {
|
||||
/// Format a human-readable status string.
|
||||
///
|
||||
/// - Live/Cached: "rates updated 5 minutes ago"
|
||||
/// - Offline: "offline -- rates from 2026-03-16 14:30:00 UTC"
|
||||
pub fn display_status(&self) -> String {
|
||||
match self.source {
|
||||
RateSource::Offline => {
|
||||
format!(
|
||||
"offline -- rates from {}",
|
||||
self.updated_at.format("%Y-%m-%d %H:%M:%S UTC")
|
||||
)
|
||||
}
|
||||
RateSource::Live | RateSource::Cached => {
|
||||
let elapsed = Utc::now().signed_duration_since(self.updated_at);
|
||||
let secs = elapsed.num_seconds().max(0);
|
||||
let relative = if secs < 60 {
|
||||
"just now".to_string()
|
||||
} else if secs < 3600 {
|
||||
let mins = secs / 60;
|
||||
format!(
|
||||
"{} minute{} ago",
|
||||
mins,
|
||||
if mins == 1 { "" } else { "s" }
|
||||
)
|
||||
} else if secs < 86400 {
|
||||
let hours = secs / 3600;
|
||||
format!(
|
||||
"{} hour{} ago",
|
||||
hours,
|
||||
if hours == 1 { "" } else { "s" }
|
||||
)
|
||||
} else {
|
||||
let days = secs / 86400;
|
||||
format!(
|
||||
"{} day{} ago",
|
||||
days,
|
||||
if days == 1 { "" } else { "s" }
|
||||
)
|
||||
};
|
||||
format!("rates updated {}", relative)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the fiat currency provider.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProviderConfig {
|
||||
/// API key for the provider (if required).
|
||||
pub api_key: Option<String>,
|
||||
/// How long before cached rates are considered stale.
|
||||
pub staleness_threshold: Duration,
|
||||
/// Path to the cache file on disk.
|
||||
pub cache_path: String,
|
||||
}
|
||||
|
||||
impl Default for ProviderConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
api_key: None,
|
||||
staleness_threshold: Duration::from_secs(3600), // 1 hour
|
||||
cache_path: "calcpad_exchange_rates.json".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Disk-based cache for exchange rates.
|
||||
pub struct ExchangeRateCache {
|
||||
/// Path to the JSON cache file.
|
||||
path: String,
|
||||
}
|
||||
|
||||
impl ExchangeRateCache {
|
||||
pub fn new(path: &str) -> Self {
|
||||
Self {
|
||||
path: path.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Save exchange rates to disk as JSON.
|
||||
pub fn save(&self, rates: &ExchangeRates) -> Result<(), CurrencyError> {
|
||||
let json = serde_json::to_string_pretty(rates)
|
||||
.map_err(|e| CurrencyError::CacheError(format!("Failed to serialize rates: {}", e)))?;
|
||||
|
||||
// Ensure parent directory exists
|
||||
if let Some(parent) = Path::new(&self.path).parent() {
|
||||
if !parent.as_os_str().is_empty() {
|
||||
fs::create_dir_all(parent).map_err(|e| {
|
||||
CurrencyError::CacheError(format!(
|
||||
"Failed to create cache directory: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
fs::write(&self.path, json)
|
||||
.map_err(|e| CurrencyError::CacheError(format!("Failed to write cache file: {}", e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load exchange rates from disk.
|
||||
/// Returns None if the cache file doesn't exist or is empty.
|
||||
pub fn load(&self) -> Result<Option<ExchangeRates>, CurrencyError> {
|
||||
let path = Path::new(&self.path);
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let json = fs::read_to_string(path)
|
||||
.map_err(|e| CurrencyError::CacheError(format!("Failed to read cache file: {}", e)))?;
|
||||
|
||||
if json.trim().is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let rates: ExchangeRates = serde_json::from_str(&json).map_err(|e| {
|
||||
CurrencyError::CacheError(format!("Failed to parse cache file: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(Some(rates))
|
||||
}
|
||||
|
||||
/// Check whether cached rates are stale based on the given threshold.
|
||||
pub fn is_stale(&self, rates: &ExchangeRates, threshold: Duration) -> bool {
|
||||
let elapsed = Utc::now()
|
||||
.signed_duration_since(rates.timestamp)
|
||||
.num_seconds()
|
||||
.max(0) as u64;
|
||||
elapsed >= threshold.as_secs()
|
||||
}
|
||||
|
||||
/// Check if the cache file exists.
|
||||
pub fn exists(&self) -> bool {
|
||||
Path::new(&self.path).exists()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
fn sample_rates() -> ExchangeRates {
|
||||
let mut rates = HashMap::new();
|
||||
rates.insert("EUR".to_string(), 0.85);
|
||||
rates.insert("GBP".to_string(), 0.73);
|
||||
rates.insert("JPY".to_string(), 110.0);
|
||||
|
||||
ExchangeRates {
|
||||
base: "USD".to_string(),
|
||||
rates,
|
||||
timestamp: Utc::now(),
|
||||
provider: "test".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_save_and_load() {
|
||||
let tmp = NamedTempFile::new().unwrap();
|
||||
let path = tmp.path().to_str().unwrap();
|
||||
|
||||
let cache = ExchangeRateCache::new(path);
|
||||
let rates = sample_rates();
|
||||
|
||||
cache.save(&rates).unwrap();
|
||||
let loaded = cache.load().unwrap().unwrap();
|
||||
|
||||
assert_eq!(loaded.base, "USD");
|
||||
assert_eq!(loaded.provider, "test");
|
||||
assert_eq!(loaded.rates.len(), 3);
|
||||
assert!((loaded.rates["EUR"] - 0.85).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_nonexistent() {
|
||||
let cache = ExchangeRateCache::new("/tmp/calcpad_nonexistent_test_cache_12345.json");
|
||||
let result = cache.load().unwrap();
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_stale_fresh() {
|
||||
let rates = sample_rates(); // timestamp = now
|
||||
let cache = ExchangeRateCache::new("/tmp/test_stale.json");
|
||||
|
||||
assert!(!cache.is_stale(&rates, Duration::from_secs(3600)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_stale_old() {
|
||||
let mut rates = sample_rates();
|
||||
rates.timestamp = Utc::now() - chrono::Duration::hours(2);
|
||||
let cache = ExchangeRateCache::new("/tmp/test_stale.json");
|
||||
|
||||
assert!(cache.is_stale(&rates, Duration::from_secs(3600)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exists() {
|
||||
let tmp = NamedTempFile::new().unwrap();
|
||||
let path = tmp.path().to_str().unwrap();
|
||||
let cache = ExchangeRateCache::new(path);
|
||||
assert!(cache.exists());
|
||||
|
||||
let cache2 = ExchangeRateCache::new("/tmp/calcpad_no_such_file_99999.json");
|
||||
assert!(!cache2.exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_config_default() {
|
||||
let config = ProviderConfig::default();
|
||||
assert!(config.api_key.is_none());
|
||||
assert_eq!(config.staleness_threshold, Duration::from_secs(3600));
|
||||
assert!(!config.cache_path.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_metadata_display_live() {
|
||||
let metadata = RateMetadata {
|
||||
updated_at: Utc::now(),
|
||||
source: RateSource::Live,
|
||||
provider: "test".to_string(),
|
||||
};
|
||||
let display = metadata.display_status();
|
||||
assert!(display.starts_with("rates updated "));
|
||||
assert!(display.contains("just now"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_metadata_display_offline() {
|
||||
let metadata = RateMetadata {
|
||||
updated_at: Utc::now() - chrono::Duration::hours(2),
|
||||
source: RateSource::Offline,
|
||||
provider: "test".to_string(),
|
||||
};
|
||||
let display = metadata.display_status();
|
||||
assert!(display.starts_with("offline -- rates from "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_metadata_display_minutes_ago() {
|
||||
let metadata = RateMetadata {
|
||||
updated_at: Utc::now() - chrono::Duration::minutes(5),
|
||||
source: RateSource::Cached,
|
||||
provider: "test".to_string(),
|
||||
};
|
||||
let display = metadata.display_status();
|
||||
assert!(display.contains("minute"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_metadata_display_hours_ago() {
|
||||
let metadata = RateMetadata {
|
||||
updated_at: Utc::now() - chrono::Duration::hours(3),
|
||||
source: RateSource::Live,
|
||||
provider: "test".to_string(),
|
||||
};
|
||||
let display = metadata.display_status();
|
||||
assert!(display.contains("hour"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_metadata_display_days_ago() {
|
||||
let metadata = RateMetadata {
|
||||
updated_at: Utc::now() - chrono::Duration::days(2),
|
||||
source: RateSource::Cached,
|
||||
provider: "test".to_string(),
|
||||
};
|
||||
let display = metadata.display_status();
|
||||
assert!(display.contains("day"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exchange_rates_serialization_roundtrip() {
|
||||
let rates = sample_rates();
|
||||
let json = serde_json::to_string(&rates).unwrap();
|
||||
let deserialized: ExchangeRates = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(deserialized.base, "USD");
|
||||
assert_eq!(deserialized.rates.len(), 3);
|
||||
assert!((deserialized.rates["EUR"] - 0.85).abs() < f64::EPSILON);
|
||||
}
|
||||
}
|
||||
427
calcpad-engine/src/currency/symbols.rs
Normal file
427
calcpad-engine/src/currency/symbols.rs
Normal file
@@ -0,0 +1,427 @@
|
||||
//! Currency symbol, code, and alias resolution.
|
||||
//!
|
||||
//! Resolves currency symbols ($, EUR, R$), ISO 4217 codes (USD, EUR, GBP),
|
||||
//! natural-language aliases (dollars, euros, pounds), and cryptocurrency
|
||||
//! symbols (BTC, ETH) to their canonical identifiers.
|
||||
|
||||
use crate::currency::crypto;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Symbol -> ISO code
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Resolve a currency symbol string (e.g., "$", "R$") to its ISO 4217 code.
|
||||
pub fn resolve_symbol(symbol: &str) -> Option<&'static str> {
|
||||
match symbol {
|
||||
"$" | "US$" => Some("USD"),
|
||||
"€" => Some("EUR"),
|
||||
"£" => Some("GBP"),
|
||||
"¥" => Some("JPY"),
|
||||
"R$" => Some("BRL"),
|
||||
"₹" => Some("INR"),
|
||||
"₩" => Some("KRW"),
|
||||
"₽" => Some("RUB"),
|
||||
"₺" => Some("TRY"),
|
||||
"₴" => Some("UAH"),
|
||||
"₱" => Some("PHP"),
|
||||
"฿" => Some("THB"),
|
||||
"₫" => Some("VND"),
|
||||
"₦" => Some("NGN"),
|
||||
"₡" => Some("CRC"),
|
||||
"₵" => Some("GHS"),
|
||||
"₸" => Some("KZT"),
|
||||
"₮" => Some("MNT"),
|
||||
"₪" => Some("ILS"),
|
||||
"kr" => Some("SEK"), // ambiguous, default to SEK
|
||||
"C$" => Some("CAD"),
|
||||
"A$" => Some("AUD"),
|
||||
"NZ$" => Some("NZD"),
|
||||
"HK$" => Some("HKD"),
|
||||
"S$" => Some("SGD"),
|
||||
"NT$" => Some("TWD"),
|
||||
"MX$" => Some("MXN"),
|
||||
"zl" | "zł" => Some("PLN"),
|
||||
"Ft" => Some("HUF"),
|
||||
"Kc" | "Kč" => Some("CZK"),
|
||||
"Rp" => Some("IDR"),
|
||||
"RM" => Some("MYR"),
|
||||
"CHF" => Some("CHF"), // symbol == code for Swiss franc
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Alias -> ISO code (natural language)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Resolve a natural-language alias to its ISO 4217 code (case-insensitive).
|
||||
pub fn resolve_alias(alias: &str) -> Option<&'static str> {
|
||||
match alias.to_lowercase().as_str() {
|
||||
"dollar" | "dollars" | "buck" | "bucks" => Some("USD"),
|
||||
"euro" | "euros" => Some("EUR"),
|
||||
"pound" | "pounds" | "quid" => Some("GBP"),
|
||||
"yen" => Some("JPY"),
|
||||
"yuan" | "renminbi" | "rmb" => Some("CNY"),
|
||||
"real" | "reais" => Some("BRL"),
|
||||
"rupee" | "rupees" => Some("INR"),
|
||||
"franc" | "francs" => Some("CHF"),
|
||||
"krona" | "kronor" => Some("SEK"),
|
||||
"krone" | "kroner" => Some("NOK"),
|
||||
"won" => Some("KRW"),
|
||||
"lira" => Some("TRY"),
|
||||
"ruble" | "rubles" | "rouble" | "roubles" => Some("RUB"),
|
||||
"ringgit" => Some("MYR"),
|
||||
"baht" => Some("THB"),
|
||||
"peso" | "pesos" => Some("MXN"),
|
||||
"rand" => Some("ZAR"),
|
||||
"shekel" | "shekels" => Some("ILS"),
|
||||
"zloty" => Some("PLN"),
|
||||
"forint" => Some("HUF"),
|
||||
"koruna" => Some("CZK"),
|
||||
"dirham" | "dirhams" => Some("AED"),
|
||||
"riyal" | "riyals" => Some("SAR"),
|
||||
"bitcoin" | "btc" | "satoshi" | "sats" => Some("BTC"),
|
||||
"ether" | "ethereum" => Some("ETH"),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ISO 4217 code validation (fiat)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Comprehensive set of recognized fiat ISO 4217 currency codes.
|
||||
/// This includes 180+ currencies actively traded or in circulation.
|
||||
pub fn is_currency_code(code: &str) -> bool {
|
||||
matches!(
|
||||
code,
|
||||
// Major / G10
|
||||
"USD" | "EUR" | "GBP" | "JPY" | "CHF" | "CAD" | "AUD" | "NZD"
|
||||
// Asia
|
||||
| "CNY" | "HKD" | "SGD" | "TWD" | "KRW" | "INR" | "PKR"
|
||||
| "BDT" | "LKR" | "NPR" | "THB" | "MYR" | "IDR" | "PHP"
|
||||
| "VND" | "KHR" | "LAK" | "MMK" | "BND" | "MVR" | "MNT"
|
||||
// Middle East
|
||||
| "TRY" | "ILS" | "AED" | "SAR" | "QAR" | "BHD" | "OMR"
|
||||
| "KWD" | "JOD" | "LBP" | "IQD" | "IRR" | "YER" | "SYP"
|
||||
// Eastern Europe / CIS
|
||||
| "RUB" | "PLN" | "CZK" | "HUF" | "RON" | "BGN" | "HRK"
|
||||
| "ISK" | "UAH" | "RSD" | "BAM" | "MKD" | "ALL" | "MDL"
|
||||
| "GEL" | "AMD" | "AZN" | "BYN"
|
||||
// Scandinavia
|
||||
| "SEK" | "NOK" | "DKK"
|
||||
// Americas
|
||||
| "MXN" | "BRL" | "ARS" | "CLP" | "COP" | "PEN" | "UYU"
|
||||
| "PYG" | "BOB" | "VES" | "CRC" | "GTQ" | "HNL" | "NIO"
|
||||
| "PAB" | "DOP" | "JMD" | "TTD" | "HTG" | "BBD" | "BSD"
|
||||
| "BZD" | "GYD" | "SRD" | "AWG" | "ANG" | "BMD" | "KYD"
|
||||
| "CUP" | "CUC" | "XCD"
|
||||
// Africa
|
||||
| "ZAR" | "EGP" | "NGN" | "KES" | "GHS" | "TZS" | "UGX"
|
||||
| "ETB" | "MAD" | "TND" | "DZD" | "LYD" | "XOF" | "XAF"
|
||||
| "CDF" | "AOA" | "MZN" | "ZMW" | "BWP" | "MWK" | "RWF"
|
||||
| "SOS" | "SDG" | "SCR" | "MUR" | "GMD" | "SLL" | "GNF"
|
||||
| "CVE" | "NAD" | "SZL" | "LSL" | "BIF" | "DJF" | "ERN"
|
||||
| "STN" | "KMF" | "MGA"
|
||||
// Pacific
|
||||
| "FJD" | "PGK" | "WST" | "TOP" | "VUV" | "SBD"
|
||||
// Central Asia
|
||||
| "KZT" | "UZS" | "KGS" | "TJS" | "TMT"
|
||||
// Other / special
|
||||
| "AFN" | "BTN" | "XDR" | "XAU" | "XAG"
|
||||
)
|
||||
}
|
||||
|
||||
/// Return a `&'static str` for a validated fiat currency code.
|
||||
fn resolve_code_static(code: &str) -> Option<&'static str> {
|
||||
// This is a macro-like approach to avoid repeating the giant list.
|
||||
// We match every code we know and return its static reference.
|
||||
match code {
|
||||
"USD" => Some("USD"), "EUR" => Some("EUR"), "GBP" => Some("GBP"),
|
||||
"JPY" => Some("JPY"), "CHF" => Some("CHF"), "CAD" => Some("CAD"),
|
||||
"AUD" => Some("AUD"), "NZD" => Some("NZD"), "CNY" => Some("CNY"),
|
||||
"HKD" => Some("HKD"), "SGD" => Some("SGD"), "TWD" => Some("TWD"),
|
||||
"KRW" => Some("KRW"), "INR" => Some("INR"), "PKR" => Some("PKR"),
|
||||
"BDT" => Some("BDT"), "LKR" => Some("LKR"), "NPR" => Some("NPR"),
|
||||
"THB" => Some("THB"), "MYR" => Some("MYR"), "IDR" => Some("IDR"),
|
||||
"PHP" => Some("PHP"), "VND" => Some("VND"), "KHR" => Some("KHR"),
|
||||
"LAK" => Some("LAK"), "MMK" => Some("MMK"), "BND" => Some("BND"),
|
||||
"MVR" => Some("MVR"), "MNT" => Some("MNT"), "TRY" => Some("TRY"),
|
||||
"ILS" => Some("ILS"), "AED" => Some("AED"), "SAR" => Some("SAR"),
|
||||
"QAR" => Some("QAR"), "BHD" => Some("BHD"), "OMR" => Some("OMR"),
|
||||
"KWD" => Some("KWD"), "JOD" => Some("JOD"), "LBP" => Some("LBP"),
|
||||
"IQD" => Some("IQD"), "IRR" => Some("IRR"), "YER" => Some("YER"),
|
||||
"SYP" => Some("SYP"), "RUB" => Some("RUB"), "PLN" => Some("PLN"),
|
||||
"CZK" => Some("CZK"), "HUF" => Some("HUF"), "RON" => Some("RON"),
|
||||
"BGN" => Some("BGN"), "HRK" => Some("HRK"), "ISK" => Some("ISK"),
|
||||
"UAH" => Some("UAH"), "RSD" => Some("RSD"), "BAM" => Some("BAM"),
|
||||
"MKD" => Some("MKD"), "ALL" => Some("ALL"), "MDL" => Some("MDL"),
|
||||
"GEL" => Some("GEL"), "AMD" => Some("AMD"), "AZN" => Some("AZN"),
|
||||
"BYN" => Some("BYN"), "SEK" => Some("SEK"), "NOK" => Some("NOK"),
|
||||
"DKK" => Some("DKK"), "MXN" => Some("MXN"), "BRL" => Some("BRL"),
|
||||
"ARS" => Some("ARS"), "CLP" => Some("CLP"), "COP" => Some("COP"),
|
||||
"PEN" => Some("PEN"), "UYU" => Some("UYU"), "PYG" => Some("PYG"),
|
||||
"BOB" => Some("BOB"), "VES" => Some("VES"), "CRC" => Some("CRC"),
|
||||
"GTQ" => Some("GTQ"), "HNL" => Some("HNL"), "NIO" => Some("NIO"),
|
||||
"PAB" => Some("PAB"), "DOP" => Some("DOP"), "JMD" => Some("JMD"),
|
||||
"TTD" => Some("TTD"), "HTG" => Some("HTG"), "BBD" => Some("BBD"),
|
||||
"BSD" => Some("BSD"), "BZD" => Some("BZD"), "GYD" => Some("GYD"),
|
||||
"SRD" => Some("SRD"), "AWG" => Some("AWG"), "ANG" => Some("ANG"),
|
||||
"BMD" => Some("BMD"), "KYD" => Some("KYD"), "CUP" => Some("CUP"),
|
||||
"CUC" => Some("CUC"), "XCD" => Some("XCD"), "ZAR" => Some("ZAR"),
|
||||
"EGP" => Some("EGP"), "NGN" => Some("NGN"), "KES" => Some("KES"),
|
||||
"GHS" => Some("GHS"), "TZS" => Some("TZS"), "UGX" => Some("UGX"),
|
||||
"ETB" => Some("ETB"), "MAD" => Some("MAD"), "TND" => Some("TND"),
|
||||
"DZD" => Some("DZD"), "LYD" => Some("LYD"), "XOF" => Some("XOF"),
|
||||
"XAF" => Some("XAF"), "CDF" => Some("CDF"), "AOA" => Some("AOA"),
|
||||
"MZN" => Some("MZN"), "ZMW" => Some("ZMW"), "BWP" => Some("BWP"),
|
||||
"MWK" => Some("MWK"), "RWF" => Some("RWF"), "SOS" => Some("SOS"),
|
||||
"SDG" => Some("SDG"), "SCR" => Some("SCR"), "MUR" => Some("MUR"),
|
||||
"GMD" => Some("GMD"), "SLL" => Some("SLL"), "GNF" => Some("GNF"),
|
||||
"CVE" => Some("CVE"), "NAD" => Some("NAD"), "SZL" => Some("SZL"),
|
||||
"LSL" => Some("LSL"), "BIF" => Some("BIF"), "DJF" => Some("DJF"),
|
||||
"ERN" => Some("ERN"), "STN" => Some("STN"), "KMF" => Some("KMF"),
|
||||
"MGA" => Some("MGA"), "FJD" => Some("FJD"), "PGK" => Some("PGK"),
|
||||
"WST" => Some("WST"), "TOP" => Some("TOP"), "VUV" => Some("VUV"),
|
||||
"SBD" => Some("SBD"), "KZT" => Some("KZT"), "UZS" => Some("UZS"),
|
||||
"KGS" => Some("KGS"), "TJS" => Some("TJS"), "TMT" => Some("TMT"),
|
||||
"AFN" => Some("AFN"), "BTN" => Some("BTN"), "XDR" => Some("XDR"),
|
||||
"XAU" => Some("XAU"), "XAG" => Some("XAG"),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Crypto symbol validation (re-exported from crypto module)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Check if a symbol is a known cryptocurrency (case-insensitive).
|
||||
pub fn is_crypto_symbol(symbol: &str) -> bool {
|
||||
crypto::is_known_crypto(symbol)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Unified resolution
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Resolve any currency reference -- symbol ($, EUR), ISO code, natural-language
|
||||
/// alias, or crypto symbol -- to a canonical uppercase code.
|
||||
///
|
||||
/// Resolution order:
|
||||
/// 1. Currency symbol (e.g. "$" -> "USD")
|
||||
/// 2. Exact fiat ISO code (e.g. "EUR" -> "EUR", case-insensitive)
|
||||
/// 3. Crypto symbol (e.g. "BTC" -> "BTC", case-insensitive)
|
||||
/// 4. Natural-language alias (e.g. "dollars" -> "USD")
|
||||
///
|
||||
/// Returns None if the input is not recognized.
|
||||
pub fn resolve_currency(input: &str) -> Option<&'static str> {
|
||||
// 1. Try symbol first (handles "$", "EUR", "R$", etc.)
|
||||
if let Some(code) = resolve_symbol(input) {
|
||||
return Some(code);
|
||||
}
|
||||
|
||||
// 2. Try exact fiat ISO code (case-insensitive)
|
||||
let upper = input.to_uppercase();
|
||||
if let Some(code) = resolve_code_static(&upper) {
|
||||
return Some(code);
|
||||
}
|
||||
|
||||
// 3. Try crypto symbol
|
||||
if is_crypto_symbol(&upper) {
|
||||
// Return a static str for the matched crypto
|
||||
return resolve_crypto_static(&upper);
|
||||
}
|
||||
|
||||
// 4. Try natural-language alias
|
||||
resolve_alias(input)
|
||||
}
|
||||
|
||||
/// Return a static str for known crypto symbols.
|
||||
fn resolve_crypto_static(symbol: &str) -> Option<&'static str> {
|
||||
crypto::CRYPTO_SYMBOLS
|
||||
.iter()
|
||||
.find(|(s, _)| *s == symbol)
|
||||
.map(|(s, _)| *s)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// --- Symbol resolution ---
|
||||
|
||||
#[test]
|
||||
fn test_resolve_symbol_usd() {
|
||||
assert_eq!(resolve_symbol("$"), Some("USD"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_symbol_eur() {
|
||||
assert_eq!(resolve_symbol("€"), Some("EUR"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_symbol_gbp() {
|
||||
assert_eq!(resolve_symbol("£"), Some("GBP"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_symbol_jpy() {
|
||||
assert_eq!(resolve_symbol("¥"), Some("JPY"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_symbol_brl() {
|
||||
assert_eq!(resolve_symbol("R$"), Some("BRL"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_symbol_inr() {
|
||||
assert_eq!(resolve_symbol("₹"), Some("INR"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_symbol_krw() {
|
||||
assert_eq!(resolve_symbol("₩"), Some("KRW"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_symbol_prefixed_dollar() {
|
||||
assert_eq!(resolve_symbol("C$"), Some("CAD"));
|
||||
assert_eq!(resolve_symbol("A$"), Some("AUD"));
|
||||
assert_eq!(resolve_symbol("NZ$"), Some("NZD"));
|
||||
assert_eq!(resolve_symbol("HK$"), Some("HKD"));
|
||||
assert_eq!(resolve_symbol("S$"), Some("SGD"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_symbol_unknown() {
|
||||
assert_eq!(resolve_symbol("X"), None);
|
||||
assert_eq!(resolve_symbol(""), None);
|
||||
}
|
||||
|
||||
// --- Alias resolution ---
|
||||
|
||||
#[test]
|
||||
fn test_resolve_alias_dollars() {
|
||||
assert_eq!(resolve_alias("dollars"), Some("USD"));
|
||||
assert_eq!(resolve_alias("dollar"), Some("USD"));
|
||||
assert_eq!(resolve_alias("Dollars"), Some("USD"));
|
||||
assert_eq!(resolve_alias("bucks"), Some("USD"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_alias_euros() {
|
||||
assert_eq!(resolve_alias("euros"), Some("EUR"));
|
||||
assert_eq!(resolve_alias("euro"), Some("EUR"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_alias_pounds() {
|
||||
assert_eq!(resolve_alias("pounds"), Some("GBP"));
|
||||
assert_eq!(resolve_alias("pound"), Some("GBP"));
|
||||
assert_eq!(resolve_alias("quid"), Some("GBP"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_alias_yen() {
|
||||
assert_eq!(resolve_alias("yen"), Some("JPY"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_alias_crypto() {
|
||||
assert_eq!(resolve_alias("bitcoin"), Some("BTC"));
|
||||
assert_eq!(resolve_alias("ether"), Some("ETH"));
|
||||
assert_eq!(resolve_alias("ethereum"), Some("ETH"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_alias_unknown() {
|
||||
assert_eq!(resolve_alias("foo"), None);
|
||||
}
|
||||
|
||||
// --- ISO code validation ---
|
||||
|
||||
#[test]
|
||||
fn test_is_currency_code_major() {
|
||||
assert!(is_currency_code("USD"));
|
||||
assert!(is_currency_code("EUR"));
|
||||
assert!(is_currency_code("GBP"));
|
||||
assert!(is_currency_code("JPY"));
|
||||
assert!(is_currency_code("CHF"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_currency_code_regional() {
|
||||
assert!(is_currency_code("BRL"));
|
||||
assert!(is_currency_code("MXN"));
|
||||
assert!(is_currency_code("ZAR"));
|
||||
assert!(is_currency_code("NGN"));
|
||||
assert!(is_currency_code("KES"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_currency_code_negative() {
|
||||
assert!(!is_currency_code("usd")); // lowercase
|
||||
assert!(!is_currency_code("XYZ"));
|
||||
assert!(!is_currency_code("kg"));
|
||||
assert!(!is_currency_code(""));
|
||||
}
|
||||
|
||||
// --- Unified resolution ---
|
||||
|
||||
#[test]
|
||||
fn test_resolve_currency_from_symbol() {
|
||||
assert_eq!(resolve_currency("$"), Some("USD"));
|
||||
assert_eq!(resolve_currency("€"), Some("EUR"));
|
||||
assert_eq!(resolve_currency("R$"), Some("BRL"));
|
||||
assert_eq!(resolve_currency("₹"), Some("INR"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_currency_from_code() {
|
||||
assert_eq!(resolve_currency("USD"), Some("USD"));
|
||||
assert_eq!(resolve_currency("EUR"), Some("EUR"));
|
||||
assert_eq!(resolve_currency("usd"), Some("USD"));
|
||||
assert_eq!(resolve_currency("eur"), Some("EUR"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_currency_from_crypto() {
|
||||
assert_eq!(resolve_currency("BTC"), Some("BTC"));
|
||||
assert_eq!(resolve_currency("ETH"), Some("ETH"));
|
||||
assert_eq!(resolve_currency("btc"), Some("BTC"));
|
||||
assert_eq!(resolve_currency("sol"), Some("SOL"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_currency_from_alias() {
|
||||
assert_eq!(resolve_currency("dollars"), Some("USD"));
|
||||
assert_eq!(resolve_currency("euros"), Some("EUR"));
|
||||
assert_eq!(resolve_currency("pounds"), Some("GBP"));
|
||||
assert_eq!(resolve_currency("yen"), Some("JPY"));
|
||||
assert_eq!(resolve_currency("bitcoin"), Some("BTC"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_currency_unknown() {
|
||||
assert_eq!(resolve_currency("foobar"), None);
|
||||
assert_eq!(resolve_currency("kg"), None);
|
||||
assert_eq!(resolve_currency("meters"), None);
|
||||
}
|
||||
|
||||
// --- Crypto symbol ---
|
||||
|
||||
#[test]
|
||||
fn test_is_crypto_symbol_positive() {
|
||||
assert!(is_crypto_symbol("BTC"));
|
||||
assert!(is_crypto_symbol("ETH"));
|
||||
assert!(is_crypto_symbol("btc"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_crypto_symbol_negative() {
|
||||
assert!(!is_crypto_symbol("USD"));
|
||||
assert!(!is_crypto_symbol("FOOBAR"));
|
||||
}
|
||||
}
|
||||
390
calcpad-engine/src/datetime/business_days.rs
Normal file
390
calcpad-engine/src/datetime/business_days.rs
Normal file
@@ -0,0 +1,390 @@
|
||||
//! Business day calculations: skip weekends, configurable holiday calendars,
|
||||
//! forward and backward counting.
|
||||
|
||||
use chrono::{Datelike, NaiveDate, Weekday};
|
||||
|
||||
/// Configuration for business day calculations.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BusinessDayConfig {
|
||||
/// Specific dates to treat as holidays (non-business days).
|
||||
/// An empty list means only weekends are skipped.
|
||||
pub holidays: Vec<NaiveDate>,
|
||||
}
|
||||
|
||||
impl Default for BusinessDayConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
holidays: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BusinessDayConfig {
|
||||
/// Check whether a given date is a business day.
|
||||
pub fn is_business_day(&self, date: NaiveDate) -> bool {
|
||||
let wd = date.weekday();
|
||||
if wd == Weekday::Sat || wd == Weekday::Sun {
|
||||
return false;
|
||||
}
|
||||
!self.holidays.contains(&date)
|
||||
}
|
||||
}
|
||||
|
||||
/// Add `count` business days to `start`, skipping weekends and holidays.
|
||||
/// Counting begins on the day **after** `start`.
|
||||
pub fn add_business_days(
|
||||
start: NaiveDate,
|
||||
count: i64,
|
||||
config: &BusinessDayConfig,
|
||||
) -> Option<NaiveDate> {
|
||||
if count == 0 {
|
||||
return Some(start);
|
||||
}
|
||||
|
||||
let mut remaining = count;
|
||||
let mut current = start;
|
||||
|
||||
while remaining > 0 {
|
||||
current = current.checked_add_signed(chrono::Duration::days(1))?;
|
||||
if config.is_business_day(current) {
|
||||
remaining -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
Some(current)
|
||||
}
|
||||
|
||||
/// Subtract `count` business days from `start` (go backward).
|
||||
/// Counting begins on the day **before** `start`.
|
||||
pub fn sub_business_days(
|
||||
start: NaiveDate,
|
||||
count: i64,
|
||||
config: &BusinessDayConfig,
|
||||
) -> Option<NaiveDate> {
|
||||
if count == 0 {
|
||||
return Some(start);
|
||||
}
|
||||
|
||||
let mut remaining = count;
|
||||
let mut current = start;
|
||||
|
||||
while remaining > 0 {
|
||||
current = current.checked_sub_signed(chrono::Duration::days(1))?;
|
||||
if config.is_business_day(current) {
|
||||
remaining -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
Some(current)
|
||||
}
|
||||
|
||||
/// Count the number of business days between two dates (exclusive of endpoints,
|
||||
/// or inclusive depending on convention -- here we count days strictly between
|
||||
/// `from` and `to`, not including `from` but including `to`).
|
||||
pub fn business_days_between(
|
||||
from: NaiveDate,
|
||||
to: NaiveDate,
|
||||
config: &BusinessDayConfig,
|
||||
) -> i64 {
|
||||
if from >= to {
|
||||
return 0;
|
||||
}
|
||||
let mut count = 0i64;
|
||||
let mut current = from;
|
||||
while current < to {
|
||||
current += chrono::Duration::days(1);
|
||||
if config.is_business_day(current) {
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
count
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// US Federal Holiday Calendar
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Generate US federal holiday dates for a given year.
|
||||
///
|
||||
/// Includes: New Year's Day, MLK Day, Presidents' Day, Memorial Day,
|
||||
/// Independence Day, Labor Day, Columbus Day, Veterans Day, Thanksgiving, Christmas.
|
||||
pub fn us_holidays(year: i32) -> Vec<NaiveDate> {
|
||||
let mut holidays = Vec::new();
|
||||
|
||||
// New Year's Day
|
||||
if let Some(d) = NaiveDate::from_ymd_opt(year, 1, 1) {
|
||||
holidays.push(d);
|
||||
}
|
||||
// MLK Day -- 3rd Monday of January
|
||||
if let Some(d) = nth_weekday_of_month(year, 1, Weekday::Mon, 3) {
|
||||
holidays.push(d);
|
||||
}
|
||||
// Presidents' Day -- 3rd Monday of February
|
||||
if let Some(d) = nth_weekday_of_month(year, 2, Weekday::Mon, 3) {
|
||||
holidays.push(d);
|
||||
}
|
||||
// Memorial Day -- last Monday of May
|
||||
if let Some(d) = last_weekday_of_month(year, 5, Weekday::Mon) {
|
||||
holidays.push(d);
|
||||
}
|
||||
// Independence Day
|
||||
if let Some(d) = NaiveDate::from_ymd_opt(year, 7, 4) {
|
||||
holidays.push(d);
|
||||
}
|
||||
// Labor Day -- 1st Monday of September
|
||||
if let Some(d) = nth_weekday_of_month(year, 9, Weekday::Mon, 1) {
|
||||
holidays.push(d);
|
||||
}
|
||||
// Columbus Day -- 2nd Monday of October
|
||||
if let Some(d) = nth_weekday_of_month(year, 10, Weekday::Mon, 2) {
|
||||
holidays.push(d);
|
||||
}
|
||||
// Veterans Day
|
||||
if let Some(d) = NaiveDate::from_ymd_opt(year, 11, 11) {
|
||||
holidays.push(d);
|
||||
}
|
||||
// Thanksgiving -- 4th Thursday of November
|
||||
if let Some(d) = nth_weekday_of_month(year, 11, Weekday::Thu, 4) {
|
||||
holidays.push(d);
|
||||
}
|
||||
// Christmas Day
|
||||
if let Some(d) = NaiveDate::from_ymd_opt(year, 12, 25) {
|
||||
holidays.push(d);
|
||||
}
|
||||
|
||||
holidays
|
||||
}
|
||||
|
||||
/// Find the Nth occurrence of a weekday in a given month (1-indexed).
|
||||
pub fn nth_weekday_of_month(
|
||||
year: i32,
|
||||
month: u32,
|
||||
weekday: Weekday,
|
||||
n: u32,
|
||||
) -> Option<NaiveDate> {
|
||||
let first = NaiveDate::from_ymd_opt(year, month, 1)?;
|
||||
let first_wd = first.weekday();
|
||||
let days_until = (weekday.num_days_from_monday() as i32
|
||||
- first_wd.num_days_from_monday() as i32
|
||||
+ 7)
|
||||
% 7;
|
||||
let day = 1 + days_until as u32 + (n - 1) * 7;
|
||||
NaiveDate::from_ymd_opt(year, month, day)
|
||||
}
|
||||
|
||||
/// Find the last occurrence of a weekday in a given month.
|
||||
pub fn last_weekday_of_month(year: i32, month: u32, weekday: Weekday) -> Option<NaiveDate> {
|
||||
let next_month = if month == 12 {
|
||||
NaiveDate::from_ymd_opt(year + 1, 1, 1)?
|
||||
} else {
|
||||
NaiveDate::from_ymd_opt(year, month + 1, 1)?
|
||||
};
|
||||
let last_day = next_month.pred_opt()?;
|
||||
let last_wd = last_day.weekday();
|
||||
let days_back = (last_wd.num_days_from_monday() as i32
|
||||
- weekday.num_days_from_monday() as i32
|
||||
+ 7)
|
||||
% 7;
|
||||
let day = last_day.day() - days_back as u32;
|
||||
NaiveDate::from_ymd_opt(year, month, day)
|
||||
}
|
||||
|
||||
/// Resolve the next occurrence of a named weekday after `reference`.
|
||||
/// If `reference` IS that day, returns the **following** week's occurrence.
|
||||
pub fn next_weekday(name: &str, reference: NaiveDate) -> Option<NaiveDate> {
|
||||
let target = parse_weekday(name)?;
|
||||
let today_wd = reference.weekday();
|
||||
let days_ahead = (target.num_days_from_monday() as i32
|
||||
- today_wd.num_days_from_monday() as i32
|
||||
+ 7)
|
||||
% 7;
|
||||
let days_ahead = if days_ahead == 0 { 7 } else { days_ahead };
|
||||
reference.checked_add_signed(chrono::Duration::days(days_ahead as i64))
|
||||
}
|
||||
|
||||
/// Parse a weekday name (full or abbreviated, case-insensitive).
|
||||
pub fn parse_weekday(name: &str) -> Option<Weekday> {
|
||||
match name.to_lowercase().as_str() {
|
||||
"monday" | "mon" => Some(Weekday::Mon),
|
||||
"tuesday" | "tue" | "tues" => Some(Weekday::Tue),
|
||||
"wednesday" | "wed" => Some(Weekday::Wed),
|
||||
"thursday" | "thu" | "thur" | "thurs" => Some(Weekday::Thu),
|
||||
"friday" | "fri" => Some(Weekday::Fri),
|
||||
"saturday" | "sat" => Some(Weekday::Sat),
|
||||
"sunday" | "sun" => Some(Weekday::Sun),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn d(y: i32, m: u32, day: u32) -> NaiveDate {
|
||||
NaiveDate::from_ymd_opt(y, m, day).unwrap()
|
||||
}
|
||||
|
||||
// -- is_business_day --
|
||||
|
||||
#[test]
|
||||
fn test_weekday_is_business_day() {
|
||||
let cfg = BusinessDayConfig::default();
|
||||
assert!(cfg.is_business_day(d(2026, 3, 17))); // Tuesday
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_weekend_not_business_day() {
|
||||
let cfg = BusinessDayConfig::default();
|
||||
assert!(!cfg.is_business_day(d(2026, 3, 21))); // Saturday
|
||||
assert!(!cfg.is_business_day(d(2026, 3, 22))); // Sunday
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_holiday_not_business_day() {
|
||||
let cfg = BusinessDayConfig {
|
||||
holidays: vec![d(2025, 12, 25)],
|
||||
};
|
||||
assert!(!cfg.is_business_day(d(2025, 12, 25))); // Thursday Christmas
|
||||
}
|
||||
|
||||
// -- add_business_days --
|
||||
|
||||
#[test]
|
||||
fn test_add_10_from_monday() {
|
||||
// March 16 2026 is Monday. 10 business days forward:
|
||||
// Day 1-5: Tue 17 ... Mon 23 (skipping Sat/Sun 21-22)
|
||||
// Wait: Day1=Tue17, Day2=Wed18, Day3=Thu19, Day4=Fri20,
|
||||
// skip Sat21 Sun22,
|
||||
// Day5=Mon23, Day6=Tue24, Day7=Wed25, Day8=Thu26, Day9=Fri27,
|
||||
// skip Sat28 Sun29,
|
||||
// Day10=Mon30
|
||||
let cfg = BusinessDayConfig::default();
|
||||
assert_eq!(add_business_days(d(2026, 3, 16), 10, &cfg), Some(d(2026, 3, 30)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_10_from_wednesday() {
|
||||
// March 18 2026 is Wednesday.
|
||||
let cfg = BusinessDayConfig::default();
|
||||
assert_eq!(
|
||||
add_business_days(d(2026, 3, 18), 10, &cfg),
|
||||
Some(d(2026, 4, 1))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_zero() {
|
||||
let cfg = BusinessDayConfig::default();
|
||||
assert_eq!(add_business_days(d(2026, 3, 18), 0, &cfg), Some(d(2026, 3, 18)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_from_saturday() {
|
||||
// Saturday March 21: 1 biz day → skip Sun → Mon 23
|
||||
let cfg = BusinessDayConfig::default();
|
||||
assert_eq!(add_business_days(d(2026, 3, 21), 1, &cfg), Some(d(2026, 3, 23)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_with_holiday() {
|
||||
// Dec 23 2025 (Tue) + 3 biz days, Christmas Dec 25 is holiday
|
||||
// Day1=Wed24, skip Thu25 (holiday), Day2=Fri26, skip Sat27 Sun28, Day3=Mon29
|
||||
let cfg = BusinessDayConfig {
|
||||
holidays: vec![d(2025, 12, 25)],
|
||||
};
|
||||
assert_eq!(
|
||||
add_business_days(d(2025, 12, 23), 3, &cfg),
|
||||
Some(d(2025, 12, 29))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_no_holidays_christmas_counts() {
|
||||
// Same scenario but no holiday calendar
|
||||
let cfg = BusinessDayConfig::default();
|
||||
assert_eq!(
|
||||
add_business_days(d(2025, 12, 23), 3, &cfg),
|
||||
Some(d(2025, 12, 26))
|
||||
);
|
||||
}
|
||||
|
||||
// -- sub_business_days --
|
||||
|
||||
#[test]
|
||||
fn test_sub_5_from_wednesday() {
|
||||
// March 18 Wed: Day1=Tue17, Day2=Mon16, skip Sun15 Sat14, Day3=Fri13, Day4=Thu12, Day5=Wed11
|
||||
let cfg = BusinessDayConfig::default();
|
||||
assert_eq!(
|
||||
sub_business_days(d(2026, 3, 18), 5, &cfg),
|
||||
Some(d(2026, 3, 11))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_5_from_monday() {
|
||||
// March 16 Mon: Day1=Fri13, skip Sat14 Sun15 (already behind),
|
||||
// wait: Mon16 back 1 cal day = Sun15 (skip), Sat14 (skip), Fri13 (day1),
|
||||
// Thu12 (day2), Wed11 (day3), Tue10 (day4), Mon9 (day5)
|
||||
let cfg = BusinessDayConfig::default();
|
||||
assert_eq!(
|
||||
sub_business_days(d(2026, 3, 16), 5, &cfg),
|
||||
Some(d(2026, 3, 9))
|
||||
);
|
||||
}
|
||||
|
||||
// -- business_days_between --
|
||||
|
||||
#[test]
|
||||
fn test_between_same_day() {
|
||||
let cfg = BusinessDayConfig::default();
|
||||
assert_eq!(business_days_between(d(2026, 3, 17), d(2026, 3, 17), &cfg), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_between_one_week() {
|
||||
// Mon to next Mon = 5 business days
|
||||
let cfg = BusinessDayConfig::default();
|
||||
assert_eq!(
|
||||
business_days_between(d(2026, 3, 16), d(2026, 3, 23), &cfg),
|
||||
5
|
||||
);
|
||||
}
|
||||
|
||||
// -- us_holidays --
|
||||
|
||||
#[test]
|
||||
fn test_us_holidays_2025() {
|
||||
let holidays = us_holidays(2025);
|
||||
assert!(holidays.contains(&d(2025, 1, 1))); // New Year's
|
||||
assert!(holidays.contains(&d(2025, 7, 4))); // Independence Day
|
||||
assert!(holidays.contains(&d(2025, 12, 25))); // Christmas
|
||||
assert!(holidays.contains(&d(2025, 1, 20))); // MLK Day
|
||||
assert!(holidays.contains(&d(2025, 11, 27))); // Thanksgiving
|
||||
assert!(holidays.contains(&d(2025, 5, 26))); // Memorial Day
|
||||
assert!(holidays.contains(&d(2025, 9, 1))); // Labor Day
|
||||
}
|
||||
|
||||
// -- next_weekday --
|
||||
|
||||
#[test]
|
||||
fn test_next_friday_from_tuesday() {
|
||||
// March 17 2026 = Tuesday, next Friday = March 20
|
||||
assert_eq!(next_weekday("Friday", d(2026, 3, 17)), Some(d(2026, 3, 20)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_tuesday_from_tuesday() {
|
||||
// Same day goes to next week
|
||||
assert_eq!(next_weekday("Tuesday", d(2026, 3, 17)), Some(d(2026, 3, 24)));
|
||||
}
|
||||
|
||||
// -- parse_weekday --
|
||||
|
||||
#[test]
|
||||
fn test_parse_weekday() {
|
||||
assert_eq!(parse_weekday("Monday"), Some(Weekday::Mon));
|
||||
assert_eq!(parse_weekday("fri"), Some(Weekday::Fri));
|
||||
assert_eq!(parse_weekday("SUNDAY"), Some(Weekday::Sun));
|
||||
assert_eq!(parse_weekday("blurday"), None);
|
||||
}
|
||||
}
|
||||
434
calcpad-engine/src/datetime/date_math.rs
Normal file
434
calcpad-engine/src/datetime/date_math.rs
Normal file
@@ -0,0 +1,434 @@
|
||||
//! Date arithmetic: today + 3 weeks, date ranges, days until X, named dates.
|
||||
|
||||
use chrono::{Datelike, Months, NaiveDate};
|
||||
|
||||
/// User preference for ambiguous date formats like 3/4/2025.
|
||||
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum DateFormat {
|
||||
/// MM/DD/YYYY (US default)
|
||||
#[default]
|
||||
US,
|
||||
/// DD/MM/YYYY (European)
|
||||
EU,
|
||||
}
|
||||
|
||||
/// A compound calendar duration (years, months, weeks, days).
|
||||
/// Stored as signed integers so negation is straightforward.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CalendarDuration {
|
||||
pub years: i64,
|
||||
pub months: i64,
|
||||
pub weeks: i64,
|
||||
pub days: i64,
|
||||
}
|
||||
|
||||
impl CalendarDuration {
|
||||
pub fn zero() -> Self {
|
||||
Self {
|
||||
years: 0,
|
||||
months: 0,
|
||||
weeks: 0,
|
||||
days: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Approximate total days (30-day months, 365-day years).
|
||||
pub fn total_days_approx(&self) -> i64 {
|
||||
self.years * 365 + self.months * 30 + self.weeks * 7 + self.days
|
||||
}
|
||||
|
||||
/// Return the negated duration.
|
||||
pub fn negate(&self) -> Self {
|
||||
Self {
|
||||
years: -self.years,
|
||||
months: -self.months,
|
||||
weeks: -self.weeks,
|
||||
days: -self.days,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CalendarDuration {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let mut parts = Vec::new();
|
||||
if self.years != 0 {
|
||||
let label = if self.years.abs() == 1 { "year" } else { "years" };
|
||||
parts.push(format!("{} {}", self.years.abs(), label));
|
||||
}
|
||||
if self.months != 0 {
|
||||
let label = if self.months.abs() == 1 {
|
||||
"month"
|
||||
} else {
|
||||
"months"
|
||||
};
|
||||
parts.push(format!("{} {}", self.months.abs(), label));
|
||||
}
|
||||
if self.weeks != 0 {
|
||||
let label = if self.weeks.abs() == 1 { "week" } else { "weeks" };
|
||||
parts.push(format!("{} {}", self.weeks.abs(), label));
|
||||
}
|
||||
if self.days != 0 {
|
||||
let label = if self.days.abs() == 1 { "day" } else { "days" };
|
||||
parts.push(format!("{} {}", self.days.abs(), label));
|
||||
}
|
||||
if parts.is_empty() {
|
||||
write!(f, "0 days")
|
||||
} else {
|
||||
write!(f, "{}", parts.join(" "))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a `CalendarDuration` to a date. Months/years use `chrono::Months` for
|
||||
/// correct calendar arithmetic; weeks and days are added as simple day offsets.
|
||||
pub fn add_duration(date: NaiveDate, dur: &CalendarDuration) -> Option<NaiveDate> {
|
||||
let mut result = date;
|
||||
|
||||
// Years (converted to months for chrono)
|
||||
if dur.years != 0 {
|
||||
if dur.years > 0 {
|
||||
result = result.checked_add_months(Months::new((dur.years * 12) as u32))?;
|
||||
} else {
|
||||
result =
|
||||
result.checked_sub_months(Months::new((dur.years.abs() * 12) as u32))?;
|
||||
}
|
||||
}
|
||||
|
||||
// Months
|
||||
if dur.months != 0 {
|
||||
if dur.months > 0 {
|
||||
result = result.checked_add_months(Months::new(dur.months as u32))?;
|
||||
} else {
|
||||
result =
|
||||
result.checked_sub_months(Months::new(dur.months.unsigned_abs() as u32))?;
|
||||
}
|
||||
}
|
||||
|
||||
// Weeks + days
|
||||
let total_days = dur.weeks * 7 + dur.days;
|
||||
if total_days != 0 {
|
||||
result = result.checked_add_signed(chrono::Duration::days(total_days))?;
|
||||
}
|
||||
|
||||
Some(result)
|
||||
}
|
||||
|
||||
/// Subtract a `CalendarDuration` from a date.
|
||||
pub fn sub_duration(date: NaiveDate, dur: &CalendarDuration) -> Option<NaiveDate> {
|
||||
add_duration(date, &dur.negate())
|
||||
}
|
||||
|
||||
/// Compute the signed difference in whole days: `to - from`.
|
||||
pub fn days_between(from: NaiveDate, to: NaiveDate) -> i64 {
|
||||
(to - from).num_days()
|
||||
}
|
||||
|
||||
/// Resolve a named date to the next occurrence on or after `reference`.
|
||||
/// Returns `None` for unrecognized names.
|
||||
pub fn resolve_named_date(name: &str, reference: NaiveDate) -> Option<NaiveDate> {
|
||||
let lower = name.to_lowercase();
|
||||
let (month, day) = match lower.as_str() {
|
||||
"christmas" | "xmas" => (12, 25),
|
||||
"newyear" | "newyears" | "new year" | "new years" | "new year's" => (1, 1),
|
||||
"valentines" | "valentine's" | "valentines day" => (2, 14),
|
||||
"halloween" => (10, 31),
|
||||
"independence day" | "july 4th" | "fourth of july" => (7, 4),
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let this_year = NaiveDate::from_ymd_opt(reference.year(), month, day)?;
|
||||
if reference <= this_year {
|
||||
Some(this_year)
|
||||
} else {
|
||||
NaiveDate::from_ymd_opt(reference.year() + 1, month, day)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a month name (full or abbreviated) to its 1-based number.
|
||||
pub fn month_number(name: &str) -> Option<u32> {
|
||||
match name.to_lowercase().as_str() {
|
||||
"january" | "jan" => Some(1),
|
||||
"february" | "feb" => Some(2),
|
||||
"march" | "mar" => Some(3),
|
||||
"april" | "apr" => Some(4),
|
||||
"may" => Some(5),
|
||||
"june" | "jun" => Some(6),
|
||||
"july" | "jul" => Some(7),
|
||||
"august" | "aug" => Some(8),
|
||||
"september" | "sep" | "sept" => Some(9),
|
||||
"october" | "oct" => Some(10),
|
||||
"november" | "nov" => Some(11),
|
||||
"december" | "dec" => Some(12),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a date for display according to the given format preference.
|
||||
pub fn format_date(date: NaiveDate, format: DateFormat) -> String {
|
||||
match format {
|
||||
DateFormat::US => date.format("%B %-d, %Y").to_string(),
|
||||
DateFormat::EU => date.format("%-d %B %Y").to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a day-count delta for display, with an optional month breakdown.
|
||||
pub fn format_day_delta(days: i64) -> String {
|
||||
let abs_days = days.unsigned_abs();
|
||||
if abs_days == 0 {
|
||||
return "0 days".to_string();
|
||||
}
|
||||
if abs_days >= 30 {
|
||||
let months = abs_days / 30;
|
||||
let remaining = abs_days % 30;
|
||||
let m_label = if months == 1 { "month" } else { "months" };
|
||||
if remaining > 0 {
|
||||
let d_label = if remaining == 1 { "day" } else { "days" };
|
||||
format!(
|
||||
"{} days ({} {} {} {})",
|
||||
abs_days, months, m_label, remaining, d_label
|
||||
)
|
||||
} else {
|
||||
format!("{} days ({} {})", abs_days, months, m_label)
|
||||
}
|
||||
} else {
|
||||
let d_label = if abs_days == 1 { "day" } else { "days" };
|
||||
format!("{} {}", abs_days, d_label)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn d(y: i32, m: u32, day: u32) -> NaiveDate {
|
||||
NaiveDate::from_ymd_opt(y, m, day).unwrap()
|
||||
}
|
||||
|
||||
// -- CalendarDuration --
|
||||
|
||||
#[test]
|
||||
fn test_duration_display_mixed() {
|
||||
let dur = CalendarDuration {
|
||||
years: 1,
|
||||
months: 2,
|
||||
weeks: 3,
|
||||
days: 4,
|
||||
};
|
||||
assert_eq!(dur.to_string(), "1 year 2 months 3 weeks 4 days");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duration_display_zero() {
|
||||
assert_eq!(CalendarDuration::zero().to_string(), "0 days");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duration_total_days_approx() {
|
||||
let dur = CalendarDuration {
|
||||
years: 1,
|
||||
months: 0,
|
||||
weeks: 0,
|
||||
days: 0,
|
||||
};
|
||||
assert_eq!(dur.total_days_approx(), 365);
|
||||
}
|
||||
|
||||
// -- add_duration / sub_duration --
|
||||
|
||||
#[test]
|
||||
fn test_add_weeks_and_days() {
|
||||
// March 17 + 3 weeks 2 days = April 9
|
||||
let result = add_duration(
|
||||
d(2026, 3, 17),
|
||||
&CalendarDuration {
|
||||
years: 0,
|
||||
months: 0,
|
||||
weeks: 3,
|
||||
days: 2,
|
||||
},
|
||||
);
|
||||
assert_eq!(result, Some(d(2026, 4, 9)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_one_year() {
|
||||
let result = add_duration(
|
||||
d(2026, 3, 17),
|
||||
&CalendarDuration {
|
||||
years: 1,
|
||||
months: 0,
|
||||
weeks: 0,
|
||||
days: 0,
|
||||
},
|
||||
);
|
||||
assert_eq!(result, Some(d(2027, 3, 17)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_one_month_end_of_month_clamp() {
|
||||
// Jan 31 + 1 month = Feb 28 (non-leap)
|
||||
let result = add_duration(
|
||||
d(2026, 1, 31),
|
||||
&CalendarDuration {
|
||||
years: 0,
|
||||
months: 1,
|
||||
weeks: 0,
|
||||
days: 0,
|
||||
},
|
||||
);
|
||||
assert_eq!(result, Some(d(2026, 2, 28)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_30_days() {
|
||||
// Jan 15, 2025 - 30 days = Dec 16, 2024
|
||||
let result = sub_duration(
|
||||
d(2025, 1, 15),
|
||||
&CalendarDuration {
|
||||
years: 0,
|
||||
months: 0,
|
||||
weeks: 0,
|
||||
days: 30,
|
||||
},
|
||||
);
|
||||
assert_eq!(result, Some(d(2024, 12, 16)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_leap_year_add_one_day() {
|
||||
let result = add_duration(
|
||||
d(2024, 2, 28),
|
||||
&CalendarDuration {
|
||||
years: 0,
|
||||
months: 0,
|
||||
weeks: 0,
|
||||
days: 1,
|
||||
},
|
||||
);
|
||||
assert_eq!(result, Some(d(2024, 2, 29)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_year_boundary() {
|
||||
let result = add_duration(
|
||||
d(2025, 12, 30),
|
||||
&CalendarDuration {
|
||||
years: 0,
|
||||
months: 0,
|
||||
weeks: 0,
|
||||
days: 5,
|
||||
},
|
||||
);
|
||||
assert_eq!(result, Some(d(2026, 1, 4)));
|
||||
}
|
||||
|
||||
// -- days_between --
|
||||
|
||||
#[test]
|
||||
fn test_days_between_same() {
|
||||
assert_eq!(days_between(d(2026, 3, 17), d(2026, 3, 17)), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_days_between_positive() {
|
||||
// March 12 to July 30 = 140 days
|
||||
assert_eq!(days_between(d(2026, 3, 12), d(2026, 7, 30)), 140);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_days_between_negative() {
|
||||
assert_eq!(days_between(d(2026, 3, 17), d(2026, 3, 1)), -16);
|
||||
}
|
||||
|
||||
// -- resolve_named_date --
|
||||
|
||||
#[test]
|
||||
fn test_christmas_before() {
|
||||
assert_eq!(
|
||||
resolve_named_date("Christmas", d(2026, 3, 17)),
|
||||
Some(d(2026, 12, 25))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_christmas_after() {
|
||||
assert_eq!(
|
||||
resolve_named_date("Christmas", d(2026, 12, 26)),
|
||||
Some(d(2027, 12, 25))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_new_years() {
|
||||
assert_eq!(
|
||||
resolve_named_date("newyear", d(2026, 3, 17)),
|
||||
Some(d(2027, 1, 1))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_halloween() {
|
||||
assert_eq!(
|
||||
resolve_named_date("halloween", d(2026, 3, 17)),
|
||||
Some(d(2026, 10, 31))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_named_date() {
|
||||
assert_eq!(resolve_named_date("festivus", d(2026, 3, 17)), None);
|
||||
}
|
||||
|
||||
// -- month_number --
|
||||
|
||||
#[test]
|
||||
fn test_month_full_name() {
|
||||
assert_eq!(month_number("January"), Some(1));
|
||||
assert_eq!(month_number("december"), Some(12));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_month_abbreviation() {
|
||||
assert_eq!(month_number("jan"), Some(1));
|
||||
assert_eq!(month_number("Sep"), Some(9));
|
||||
}
|
||||
|
||||
// -- format_date --
|
||||
|
||||
#[test]
|
||||
fn test_format_date_us() {
|
||||
assert_eq!(format_date(d(2025, 1, 15), DateFormat::US), "January 15, 2025");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_date_eu() {
|
||||
assert_eq!(format_date(d(2025, 1, 15), DateFormat::EU), "15 January 2025");
|
||||
}
|
||||
|
||||
// -- format_day_delta --
|
||||
|
||||
#[test]
|
||||
fn test_format_delta_zero() {
|
||||
assert_eq!(format_day_delta(0), "0 days");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_delta_one_day() {
|
||||
assert_eq!(format_day_delta(1), "1 day");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_delta_short() {
|
||||
assert_eq!(format_day_delta(16), "16 days");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_delta_with_months() {
|
||||
assert_eq!(format_day_delta(140), "140 days (4 months 20 days)");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_delta_exact_months() {
|
||||
assert_eq!(format_day_delta(60), "60 days (2 months)");
|
||||
}
|
||||
}
|
||||
49
calcpad-engine/src/datetime/mod.rs
Normal file
49
calcpad-engine/src/datetime/mod.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
//! Unified date/time/timezone system for the calcpad engine.
|
||||
//!
|
||||
//! This module consolidates all temporal calculations:
|
||||
//!
|
||||
//! - **date_math**: Date arithmetic, named dates, calendar durations, formatting
|
||||
//! - **time_math**: Time arithmetic, 12/24-hour support, time ranges
|
||||
//! - **timezone**: Timezone resolution (500+ city names, abbreviations, IANA),
|
||||
//! cross-zone conversion with DST awareness
|
||||
//! - **business_days**: Business day calculations, holiday calendars, weekday resolution
|
||||
//! - **unix**: Unix timestamp <-> human-readable conversions
|
||||
//! - **relative**: Relative time expressions ("2 hours ago", "next Wednesday")
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! ```rust
|
||||
//! use calcpad_engine::datetime::date_math::{add_duration, CalendarDuration, DateFormat};
|
||||
//! use chrono::NaiveDate;
|
||||
//!
|
||||
//! let today = NaiveDate::from_ymd_opt(2026, 3, 17).unwrap();
|
||||
//! let dur = CalendarDuration { years: 0, months: 0, weeks: 3, days: 2 };
|
||||
//! let result = add_duration(today, &dur).unwrap();
|
||||
//! assert_eq!(result, NaiveDate::from_ymd_opt(2026, 4, 9).unwrap());
|
||||
//! ```
|
||||
|
||||
pub mod business_days;
|
||||
pub mod date_math;
|
||||
pub mod relative;
|
||||
pub mod time_math;
|
||||
pub mod timezone;
|
||||
pub mod unix;
|
||||
|
||||
// Re-export core types for convenience.
|
||||
pub use business_days::{add_business_days, sub_business_days, us_holidays, BusinessDayConfig};
|
||||
pub use date_math::{
|
||||
add_duration, days_between, format_date, format_day_delta, month_number,
|
||||
resolve_named_date, sub_duration, CalendarDuration, DateFormat,
|
||||
};
|
||||
pub use relative::{
|
||||
eval_day_of_week_ref, eval_named_relative_day, eval_relative_offset, RelativeDirection,
|
||||
RelativeResult, RelativeUnit,
|
||||
};
|
||||
pub use time_math::{
|
||||
add_time_duration, duration_between, format_time, format_time_result, sub_time_duration,
|
||||
TimeDuration, TimeFormat, TimeResult,
|
||||
};
|
||||
pub use timezone::{
|
||||
convert_time, current_time_in, format_zoned_time, resolve_timezone, ZonedTimeResult,
|
||||
};
|
||||
pub use unix::{from_timestamp_in_tz, from_timestamp_utc, to_timestamp_in_tz, to_timestamp_utc};
|
||||
320
calcpad-engine/src/datetime/relative.rs
Normal file
320
calcpad-engine/src/datetime/relative.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
//! Relative time expressions: "2 hours ago", "in 3 days", "next Wednesday",
|
||||
//! "last Friday", "tomorrow at 3pm".
|
||||
|
||||
use chrono::{Datelike, Duration, NaiveDate, NaiveTime, Timelike, Weekday};
|
||||
|
||||
/// The direction of a relative expression.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum RelativeDirection {
|
||||
/// In the past: "2 hours ago", "last Monday"
|
||||
Past,
|
||||
/// In the future: "in 3 days", "next Friday"
|
||||
Future,
|
||||
}
|
||||
|
||||
/// A time unit for relative offsets.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum RelativeUnit {
|
||||
Minutes,
|
||||
Hours,
|
||||
Days,
|
||||
Weeks,
|
||||
Months,
|
||||
}
|
||||
|
||||
/// The result of evaluating a relative time expression.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RelativeResult {
|
||||
pub date: NaiveDate,
|
||||
/// Optional time-of-day (present for time-level expressions or "tomorrow at 3pm").
|
||||
pub time: Option<NaiveTime>,
|
||||
}
|
||||
|
||||
/// Evaluate a relative offset: "3 days ago", "in 2 weeks", etc.
|
||||
///
|
||||
/// For sub-day units (Hours, Minutes), the `now_time` is used as the base
|
||||
/// and the result will include a time component. For day-and-above units
|
||||
/// only the date is adjusted.
|
||||
pub fn eval_relative_offset(
|
||||
amount: i64,
|
||||
unit: RelativeUnit,
|
||||
direction: RelativeDirection,
|
||||
now_date: NaiveDate,
|
||||
now_time: NaiveTime,
|
||||
) -> Option<RelativeResult> {
|
||||
let signed = match direction {
|
||||
RelativeDirection::Past => -amount,
|
||||
RelativeDirection::Future => amount,
|
||||
};
|
||||
|
||||
match unit {
|
||||
RelativeUnit::Minutes => {
|
||||
let total_minutes =
|
||||
now_time.hour() as i64 * 60 + now_time.minute() as i64 + signed;
|
||||
let day_offset = total_minutes.div_euclid(24 * 60);
|
||||
let normalized = total_minutes.rem_euclid(24 * 60);
|
||||
let hour = (normalized / 60) as u32;
|
||||
let minute = (normalized % 60) as u32;
|
||||
let date = now_date.checked_add_signed(Duration::days(day_offset))?;
|
||||
Some(RelativeResult {
|
||||
date,
|
||||
time: NaiveTime::from_hms_opt(hour, minute, 0),
|
||||
})
|
||||
}
|
||||
RelativeUnit::Hours => {
|
||||
let total_minutes =
|
||||
now_time.hour() as i64 * 60 + now_time.minute() as i64 + signed * 60;
|
||||
let day_offset = total_minutes.div_euclid(24 * 60);
|
||||
let normalized = total_minutes.rem_euclid(24 * 60);
|
||||
let hour = (normalized / 60) as u32;
|
||||
let minute = (normalized % 60) as u32;
|
||||
let date = now_date.checked_add_signed(Duration::days(day_offset))?;
|
||||
Some(RelativeResult {
|
||||
date,
|
||||
time: NaiveTime::from_hms_opt(hour, minute, 0),
|
||||
})
|
||||
}
|
||||
RelativeUnit::Days => {
|
||||
let date = now_date.checked_add_signed(Duration::days(signed))?;
|
||||
Some(RelativeResult { date, time: None })
|
||||
}
|
||||
RelativeUnit::Weeks => {
|
||||
let date = now_date.checked_add_signed(Duration::weeks(signed))?;
|
||||
Some(RelativeResult { date, time: None })
|
||||
}
|
||||
RelativeUnit::Months => {
|
||||
use chrono::Months;
|
||||
let date = if signed > 0 {
|
||||
now_date.checked_add_months(Months::new(signed as u32))?
|
||||
} else {
|
||||
now_date.checked_sub_months(Months::new((-signed) as u32))?
|
||||
};
|
||||
Some(RelativeResult { date, time: None })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate "next <weekday>" or "last <weekday>".
|
||||
///
|
||||
/// - `next`: finds the **next** occurrence of the weekday strictly after `reference`.
|
||||
/// - `last`: finds the most recent **past** occurrence strictly before `reference`.
|
||||
///
|
||||
/// Optionally takes a time-of-day (e.g. "next Wednesday at 3pm").
|
||||
pub fn eval_day_of_week_ref(
|
||||
weekday: Weekday,
|
||||
direction: RelativeDirection,
|
||||
reference: NaiveDate,
|
||||
time_of_day: Option<NaiveTime>,
|
||||
) -> Option<RelativeResult> {
|
||||
let current_wd = reference.weekday();
|
||||
let date = match direction {
|
||||
RelativeDirection::Future => {
|
||||
let days_ahead = (weekday.num_days_from_monday() as i32
|
||||
- current_wd.num_days_from_monday() as i32
|
||||
+ 7)
|
||||
% 7;
|
||||
let days_ahead = if days_ahead == 0 { 7 } else { days_ahead };
|
||||
reference.checked_add_signed(Duration::days(days_ahead as i64))?
|
||||
}
|
||||
RelativeDirection::Past => {
|
||||
let days_back = (current_wd.num_days_from_monday() as i32
|
||||
- weekday.num_days_from_monday() as i32
|
||||
+ 7)
|
||||
% 7;
|
||||
let days_back = if days_back == 0 { 7 } else { days_back };
|
||||
reference.checked_sub_signed(Duration::days(days_back as i64))?
|
||||
}
|
||||
};
|
||||
|
||||
Some(RelativeResult {
|
||||
date,
|
||||
time: time_of_day,
|
||||
})
|
||||
}
|
||||
|
||||
/// Evaluate "tomorrow" / "yesterday" with an optional time-of-day.
|
||||
pub fn eval_named_relative_day(
|
||||
offset_days: i64,
|
||||
reference: NaiveDate,
|
||||
time_of_day: Option<NaiveTime>,
|
||||
) -> Option<RelativeResult> {
|
||||
let date = reference.checked_add_signed(Duration::days(offset_days))?;
|
||||
Some(RelativeResult {
|
||||
date,
|
||||
time: time_of_day,
|
||||
})
|
||||
}
|
||||
|
||||
/// Format a `RelativeResult` for display.
|
||||
pub fn format_relative_result(
|
||||
result: &RelativeResult,
|
||||
date_format: crate::datetime::date_math::DateFormat,
|
||||
) -> String {
|
||||
let date_str = result
|
||||
.date
|
||||
.format("%A, %B %-d, %Y")
|
||||
.to_string();
|
||||
match &result.time {
|
||||
Some(t) => {
|
||||
let (h12, is_pm) = crate::datetime::time_math::to_12h(t.hour());
|
||||
let ampm = if is_pm { "PM" } else { "AM" };
|
||||
format!("{} at {}:{:02} {}", date_str, h12, t.minute(), ampm)
|
||||
}
|
||||
None => match date_format {
|
||||
crate::datetime::date_math::DateFormat::US => {
|
||||
crate::datetime::date_math::format_date(result.date, date_format)
|
||||
}
|
||||
crate::datetime::date_math::DateFormat::EU => {
|
||||
crate::datetime::date_math::format_date(result.date, date_format)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::NaiveTime;
|
||||
|
||||
fn d(y: i32, m: u32, day: u32) -> NaiveDate {
|
||||
NaiveDate::from_ymd_opt(y, m, day).unwrap()
|
||||
}
|
||||
|
||||
fn t(h: u32, m: u32) -> NaiveTime {
|
||||
NaiveTime::from_hms_opt(h, m, 0).unwrap()
|
||||
}
|
||||
|
||||
// -- eval_relative_offset --
|
||||
|
||||
#[test]
|
||||
fn test_3_days_ago() {
|
||||
let r = eval_relative_offset(3, RelativeUnit::Days, RelativeDirection::Past, d(2026, 3, 17), t(10, 0)).unwrap();
|
||||
assert_eq!(r.date, d(2026, 3, 14));
|
||||
assert_eq!(r.time, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_in_2_weeks() {
|
||||
let r = eval_relative_offset(2, RelativeUnit::Weeks, RelativeDirection::Future, d(2026, 3, 17), t(10, 0)).unwrap();
|
||||
assert_eq!(r.date, d(2026, 3, 31));
|
||||
assert_eq!(r.time, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_2_hours_ago() {
|
||||
let r = eval_relative_offset(2, RelativeUnit::Hours, RelativeDirection::Past, d(2026, 3, 17), t(10, 30)).unwrap();
|
||||
assert_eq!(r.date, d(2026, 3, 17));
|
||||
assert_eq!(r.time, Some(t(8, 30)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_5_hours_from_now_crossing_midnight() {
|
||||
// 10:00 PM + 5 hours = 3:00 AM next day
|
||||
let r = eval_relative_offset(5, RelativeUnit::Hours, RelativeDirection::Future, d(2026, 3, 17), t(22, 0)).unwrap();
|
||||
assert_eq!(r.date, d(2026, 3, 18));
|
||||
assert_eq!(r.time, Some(t(3, 0)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_45_minutes_ago() {
|
||||
let r = eval_relative_offset(45, RelativeUnit::Minutes, RelativeDirection::Past, d(2026, 3, 17), t(10, 30)).unwrap();
|
||||
assert_eq!(r.date, d(2026, 3, 17));
|
||||
assert_eq!(r.time, Some(t(9, 45)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_in_3_months() {
|
||||
let r = eval_relative_offset(3, RelativeUnit::Months, RelativeDirection::Future, d(2026, 3, 17), t(10, 0)).unwrap();
|
||||
assert_eq!(r.date, d(2026, 6, 17));
|
||||
assert_eq!(r.time, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_2_months_ago() {
|
||||
let r = eval_relative_offset(2, RelativeUnit::Months, RelativeDirection::Past, d(2026, 3, 17), t(10, 0)).unwrap();
|
||||
assert_eq!(r.date, d(2026, 1, 17));
|
||||
}
|
||||
|
||||
// -- eval_day_of_week_ref --
|
||||
|
||||
#[test]
|
||||
fn test_next_wednesday() {
|
||||
// March 17 2026 is Tuesday, next Wednesday = March 18
|
||||
let r = eval_day_of_week_ref(Weekday::Wed, RelativeDirection::Future, d(2026, 3, 17), None).unwrap();
|
||||
assert_eq!(r.date, d(2026, 3, 18));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_tuesday_from_tuesday() {
|
||||
// Same day → next week
|
||||
let r = eval_day_of_week_ref(Weekday::Tue, RelativeDirection::Future, d(2026, 3, 17), None).unwrap();
|
||||
assert_eq!(r.date, d(2026, 3, 24));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_last_monday() {
|
||||
// March 17 Tue, last Monday = March 16
|
||||
let r = eval_day_of_week_ref(Weekday::Mon, RelativeDirection::Past, d(2026, 3, 17), None).unwrap();
|
||||
assert_eq!(r.date, d(2026, 3, 16));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_last_tuesday_from_tuesday() {
|
||||
// Same day → previous week
|
||||
let r = eval_day_of_week_ref(Weekday::Tue, RelativeDirection::Past, d(2026, 3, 17), None).unwrap();
|
||||
assert_eq!(r.date, d(2026, 3, 10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_friday_at_3pm() {
|
||||
let r = eval_day_of_week_ref(Weekday::Fri, RelativeDirection::Future, d(2026, 3, 17), Some(t(15, 0))).unwrap();
|
||||
assert_eq!(r.date, d(2026, 3, 20));
|
||||
assert_eq!(r.time, Some(t(15, 0)));
|
||||
}
|
||||
|
||||
// -- eval_named_relative_day --
|
||||
|
||||
#[test]
|
||||
fn test_tomorrow() {
|
||||
let r = eval_named_relative_day(1, d(2026, 3, 17), None).unwrap();
|
||||
assert_eq!(r.date, d(2026, 3, 18));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_yesterday() {
|
||||
let r = eval_named_relative_day(-1, d(2026, 3, 17), None).unwrap();
|
||||
assert_eq!(r.date, d(2026, 3, 16));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tomorrow_at_3pm() {
|
||||
let r = eval_named_relative_day(1, d(2026, 3, 17), Some(t(15, 0))).unwrap();
|
||||
assert_eq!(r.date, d(2026, 3, 18));
|
||||
assert_eq!(r.time, Some(t(15, 0)));
|
||||
}
|
||||
|
||||
// -- format_relative_result --
|
||||
|
||||
#[test]
|
||||
fn test_format_date_only() {
|
||||
let r = RelativeResult {
|
||||
date: d(2026, 3, 20),
|
||||
time: None,
|
||||
};
|
||||
assert_eq!(
|
||||
format_relative_result(&r, crate::datetime::date_math::DateFormat::US),
|
||||
"March 20, 2026"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_with_time() {
|
||||
let r = RelativeResult {
|
||||
date: d(2026, 3, 20),
|
||||
time: Some(t(15, 0)),
|
||||
};
|
||||
let s = format_relative_result(&r, crate::datetime::date_math::DateFormat::US);
|
||||
assert!(s.contains("3:00 PM"));
|
||||
assert!(s.contains("2026"));
|
||||
}
|
||||
}
|
||||
334
calcpad-engine/src/datetime/time_math.rs
Normal file
334
calcpad-engine/src/datetime/time_math.rs
Normal file
@@ -0,0 +1,334 @@
|
||||
//! Time arithmetic: 3:35 AM + 9h20m, duration between times, 12/24-hour display.
|
||||
|
||||
use chrono::{NaiveTime, Timelike};
|
||||
|
||||
/// User preference for time display.
|
||||
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TimeFormat {
|
||||
/// 12-hour with AM/PM (e.g., 3:35 PM)
|
||||
#[default]
|
||||
TwelveHour,
|
||||
/// 24-hour (e.g., 15:35)
|
||||
TwentyFourHour,
|
||||
}
|
||||
|
||||
/// A time-only duration in hours and minutes.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct TimeDuration {
|
||||
pub hours: i64,
|
||||
pub minutes: i64,
|
||||
}
|
||||
|
||||
impl TimeDuration {
|
||||
pub fn zero() -> Self {
|
||||
Self {
|
||||
hours: 0,
|
||||
minutes: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Total signed minutes.
|
||||
pub fn total_minutes(&self) -> i64 {
|
||||
self.hours * 60 + self.minutes
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TimeDuration {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let mut parts = Vec::new();
|
||||
if self.hours != 0 {
|
||||
let label = if self.hours == 1 { "hour" } else { "hours" };
|
||||
parts.push(format!("{} {}", self.hours, label));
|
||||
}
|
||||
if self.minutes != 0 {
|
||||
let label = if self.minutes == 1 {
|
||||
"minute"
|
||||
} else {
|
||||
"minutes"
|
||||
};
|
||||
parts.push(format!("{} {}", self.minutes, label));
|
||||
}
|
||||
if parts.is_empty() {
|
||||
write!(f, "0 minutes")
|
||||
} else {
|
||||
write!(f, "{}", parts.join(" "))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of adding/subtracting a duration to/from a time.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct TimeResult {
|
||||
pub time: NaiveTime,
|
||||
/// How many days the result has rolled past midnight.
|
||||
/// +1 = next day, -1 = previous day, 0 = same day.
|
||||
pub day_offset: i32,
|
||||
}
|
||||
|
||||
/// Add hours and minutes to a time, returning the result and any day overflow.
|
||||
pub fn add_time_duration(time: NaiveTime, hours: i64, minutes: i64) -> TimeResult {
|
||||
let time_minutes = time.hour() as i64 * 60 + time.minute() as i64;
|
||||
let total_add = hours * 60 + minutes;
|
||||
let result_minutes = time_minutes + total_add;
|
||||
|
||||
let day_offset = result_minutes.div_euclid(24 * 60) as i32;
|
||||
let normalized = result_minutes.rem_euclid(24 * 60);
|
||||
|
||||
let hour = (normalized / 60) as u32;
|
||||
let minute = (normalized % 60) as u32;
|
||||
|
||||
TimeResult {
|
||||
time: NaiveTime::from_hms_opt(hour, minute, 0).unwrap_or(NaiveTime::from_hms_opt(0, 0, 0).unwrap()),
|
||||
day_offset,
|
||||
}
|
||||
}
|
||||
|
||||
/// Subtract hours and minutes from a time.
|
||||
pub fn sub_time_duration(time: NaiveTime, hours: i64, minutes: i64) -> TimeResult {
|
||||
add_time_duration(time, -hours, -minutes)
|
||||
}
|
||||
|
||||
/// Calculate the duration between two times. If `to` < `from`, assumes midnight
|
||||
/// crossing (i.e. `to` is on the next day).
|
||||
pub fn duration_between(from: NaiveTime, to: NaiveTime) -> TimeDuration {
|
||||
let m1 = from.hour() as i64 * 60 + from.minute() as i64;
|
||||
let m2 = to.hour() as i64 * 60 + to.minute() as i64;
|
||||
|
||||
let diff = if m2 >= m1 { m2 - m1 } else { (24 * 60 - m1) + m2 };
|
||||
|
||||
TimeDuration {
|
||||
hours: diff / 60,
|
||||
minutes: diff % 60,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a 24-hour value to 12-hour. Returns `(hour_12, is_pm)`.
|
||||
pub fn to_12h(hour24: u32) -> (u32, bool) {
|
||||
match hour24 {
|
||||
0 => (12, false),
|
||||
1..=11 => (hour24, false),
|
||||
12 => (12, true),
|
||||
13..=23 => (hour24 - 12, true),
|
||||
_ => (hour24, false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse an AM/PM indicator. Returns `Some(is_pm)` or `None`.
|
||||
pub fn parse_ampm(word: &str) -> Option<bool> {
|
||||
match word.to_lowercase().as_str() {
|
||||
"am" | "a" => Some(false),
|
||||
"pm" | "p" => Some(true),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert 12-hour time to 24-hour.
|
||||
pub fn to_24h(hour12: u32, is_pm: bool) -> u32 {
|
||||
if is_pm {
|
||||
if hour12 == 12 {
|
||||
12
|
||||
} else {
|
||||
hour12 + 12
|
||||
}
|
||||
} else if hour12 == 12 {
|
||||
0
|
||||
} else {
|
||||
hour12
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a `NaiveTime` according to the user's preference.
|
||||
pub fn format_time(time: NaiveTime, format: TimeFormat) -> String {
|
||||
match format {
|
||||
TimeFormat::TwelveHour => {
|
||||
let (h12, is_pm) = to_12h(time.hour());
|
||||
let ampm = if is_pm { "PM" } else { "AM" };
|
||||
format!("{}:{:02} {}", h12, time.minute(), ampm)
|
||||
}
|
||||
TimeFormat::TwentyFourHour => {
|
||||
format!("{}:{:02}", time.hour(), time.minute())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a `TimeResult` with optional day-offset annotation.
|
||||
pub fn format_time_result(tr: &TimeResult, format: TimeFormat) -> String {
|
||||
let time_str = format_time(tr.time, format);
|
||||
if tr.day_offset > 0 {
|
||||
format!("{} (next day)", time_str)
|
||||
} else if tr.day_offset < 0 {
|
||||
format!("{} (previous day)", time_str)
|
||||
} else {
|
||||
time_str
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn t(h: u32, m: u32) -> NaiveTime {
|
||||
NaiveTime::from_hms_opt(h, m, 0).unwrap()
|
||||
}
|
||||
|
||||
// -- add / sub --
|
||||
|
||||
#[test]
|
||||
fn test_add_no_rollover() {
|
||||
// 3:35 AM + 9h20m = 12:55 PM
|
||||
let r = add_time_duration(t(3, 35), 9, 20);
|
||||
assert_eq!(r.time, t(12, 55));
|
||||
assert_eq!(r.day_offset, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_with_rollover() {
|
||||
// 3:35 PM (15:35) + 9h20m = 0:55 AM next day
|
||||
let r = add_time_duration(t(15, 35), 9, 20);
|
||||
assert_eq!(r.time, t(0, 55));
|
||||
assert_eq!(r.day_offset, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_no_rollover() {
|
||||
// 14:30 - 2h45m = 11:45
|
||||
let r = sub_time_duration(t(14, 30), 2, 45);
|
||||
assert_eq!(r.time, t(11, 45));
|
||||
assert_eq!(r.day_offset, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_past_midnight() {
|
||||
// 1:00 AM - 3h = 10:00 PM previous day
|
||||
let r = sub_time_duration(t(1, 0), 3, 0);
|
||||
assert_eq!(r.time, t(22, 0));
|
||||
assert_eq!(r.day_offset, -1);
|
||||
}
|
||||
|
||||
// -- duration_between --
|
||||
|
||||
#[test]
|
||||
fn test_duration_workday() {
|
||||
// 9:00 AM to 5:30 PM = 8h30m
|
||||
let d = duration_between(t(9, 0), t(17, 30));
|
||||
assert_eq!(d.hours, 8);
|
||||
assert_eq!(d.minutes, 30);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duration_midnight_crossing() {
|
||||
// 11:00 PM to 2:00 AM = 3h
|
||||
let d = duration_between(t(23, 0), t(2, 0));
|
||||
assert_eq!(d.hours, 3);
|
||||
assert_eq!(d.minutes, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duration_same_time() {
|
||||
let d = duration_between(t(9, 0), t(9, 0));
|
||||
assert_eq!(d.hours, 0);
|
||||
assert_eq!(d.minutes, 0);
|
||||
}
|
||||
|
||||
// -- 12h/24h conversion --
|
||||
|
||||
#[test]
|
||||
fn test_to_12h_midnight() {
|
||||
assert_eq!(to_12h(0), (12, false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_12h_noon() {
|
||||
assert_eq!(to_12h(12), (12, true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_12h_afternoon() {
|
||||
assert_eq!(to_12h(15), (3, true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_24h_am() {
|
||||
assert_eq!(to_24h(3, false), 3);
|
||||
assert_eq!(to_24h(12, false), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_24h_pm() {
|
||||
assert_eq!(to_24h(3, true), 15);
|
||||
assert_eq!(to_24h(12, true), 12);
|
||||
}
|
||||
|
||||
// -- format --
|
||||
|
||||
#[test]
|
||||
fn test_format_12h() {
|
||||
assert_eq!(format_time(t(12, 55), TimeFormat::TwelveHour), "12:55 PM");
|
||||
assert_eq!(format_time(t(0, 0), TimeFormat::TwelveHour), "12:00 AM");
|
||||
assert_eq!(format_time(t(15, 35), TimeFormat::TwelveHour), "3:35 PM");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_24h() {
|
||||
assert_eq!(format_time(t(11, 45), TimeFormat::TwentyFourHour), "11:45");
|
||||
assert_eq!(format_time(t(0, 0), TimeFormat::TwentyFourHour), "0:00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_time_result_next_day() {
|
||||
let tr = TimeResult {
|
||||
time: t(0, 55),
|
||||
day_offset: 1,
|
||||
};
|
||||
assert_eq!(
|
||||
format_time_result(&tr, TimeFormat::TwelveHour),
|
||||
"12:55 AM (next day)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_time_result_prev_day() {
|
||||
let tr = TimeResult {
|
||||
time: t(22, 0),
|
||||
day_offset: -1,
|
||||
};
|
||||
assert_eq!(
|
||||
format_time_result(&tr, TimeFormat::TwelveHour),
|
||||
"10:00 PM (previous day)"
|
||||
);
|
||||
}
|
||||
|
||||
// -- TimeDuration display --
|
||||
|
||||
#[test]
|
||||
fn test_time_duration_display() {
|
||||
let d = TimeDuration {
|
||||
hours: 8,
|
||||
minutes: 30,
|
||||
};
|
||||
assert_eq!(d.to_string(), "8 hours 30 minutes");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_time_duration_display_zero() {
|
||||
assert_eq!(TimeDuration::zero().to_string(), "0 minutes");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_time_duration_display_hours_only() {
|
||||
let d = TimeDuration {
|
||||
hours: 3,
|
||||
minutes: 0,
|
||||
};
|
||||
assert_eq!(d.to_string(), "3 hours");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_time_duration_display_singular() {
|
||||
let d = TimeDuration {
|
||||
hours: 1,
|
||||
minutes: 1,
|
||||
};
|
||||
assert_eq!(d.to_string(), "1 hour 1 minute");
|
||||
}
|
||||
}
|
||||
648
calcpad-engine/src/datetime/timezone.rs
Normal file
648
calcpad-engine/src/datetime/timezone.rs
Normal file
@@ -0,0 +1,648 @@
|
||||
//! Timezone conversions: city name / abbreviation lookup to IANA timezone,
|
||||
//! time conversion between zones with DST awareness via `chrono-tz`.
|
||||
|
||||
use chrono::{NaiveDate, NaiveTime, TimeZone, Timelike};
|
||||
use chrono_tz::Tz;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
/// The result of converting a time between timezones.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ZonedTimeResult {
|
||||
/// The hour in 12-hour format (1-12).
|
||||
pub hour12: u32,
|
||||
/// The minute.
|
||||
pub minute: u32,
|
||||
/// Whether this is PM.
|
||||
pub is_pm: bool,
|
||||
/// The timezone abbreviation in the target zone (e.g. "EDT", "JST").
|
||||
pub tz_abbr: String,
|
||||
/// The date in the target timezone.
|
||||
pub date: NaiveDate,
|
||||
/// Whether the date differs from the source date (date boundary crossing).
|
||||
pub date_changed: bool,
|
||||
}
|
||||
|
||||
/// Resolve a timezone string to a `chrono_tz::Tz`.
|
||||
///
|
||||
/// Accepts:
|
||||
/// - City names: "Tokyo", "New York", "Los Angeles"
|
||||
/// - Abbreviations: "EST", "PST", "CET", "JST"
|
||||
/// - Disambiguation: "Portland, ME" vs "Portland, OR"
|
||||
/// - Country names: "Japan", "India", "UK"
|
||||
/// - IANA identifiers: "America/New_York"
|
||||
pub fn resolve_timezone(name: &str) -> Option<Tz> {
|
||||
let normalized = name.trim().to_lowercase();
|
||||
|
||||
// Try abbreviation first (exact match)
|
||||
if let Some(tz) = abbreviation_map().get(normalized.as_str()) {
|
||||
return Some(*tz);
|
||||
}
|
||||
|
||||
// Try city name lookup
|
||||
if let Some(tz) = city_map().get(normalized.as_str()) {
|
||||
return Some(*tz);
|
||||
}
|
||||
|
||||
// Try direct IANA parse
|
||||
if let Ok(tz) = normalized.parse::<Tz>() {
|
||||
return Some(tz);
|
||||
}
|
||||
|
||||
// Try replacing spaces with underscores for IANA
|
||||
let iana_attempt = name.trim().replace(' ', "_");
|
||||
if let Ok(tz) = iana_attempt.parse::<Tz>() {
|
||||
return Some(tz);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Convert a time from one timezone to another.
|
||||
///
|
||||
/// - `hour24`, `minute`: the source time in 24-hour format
|
||||
/// - `source_date`: the date in the source timezone
|
||||
/// - `source_tz`, `target_tz`: resolved timezone objects
|
||||
///
|
||||
/// Returns `None` if the local time is ambiguous or invalid (e.g. during DST gap).
|
||||
pub fn convert_time(
|
||||
hour24: u32,
|
||||
minute: u32,
|
||||
source_date: NaiveDate,
|
||||
source_tz: Tz,
|
||||
target_tz: Tz,
|
||||
) -> Option<ZonedTimeResult> {
|
||||
let source_time = NaiveTime::from_hms_opt(hour24, minute, 0)?;
|
||||
let source_naive = source_date.and_time(source_time);
|
||||
|
||||
let source_dt = source_tz
|
||||
.from_local_datetime(&source_naive)
|
||||
.single()?;
|
||||
|
||||
let target_dt = source_dt.with_timezone(&target_tz);
|
||||
let target_date = target_dt.date_naive();
|
||||
let target_hour = target_dt.hour();
|
||||
let target_minute = target_dt.minute();
|
||||
let tz_abbr = target_dt.format("%Z").to_string();
|
||||
|
||||
let (h12, is_pm) = crate::datetime::time_math::to_12h(target_hour);
|
||||
|
||||
Some(ZonedTimeResult {
|
||||
hour12: h12,
|
||||
minute: target_minute,
|
||||
is_pm,
|
||||
tz_abbr,
|
||||
date: target_date,
|
||||
date_changed: target_date != source_date,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the current time in a given timezone.
|
||||
pub fn current_time_in(
|
||||
tz: Tz,
|
||||
now_utc: chrono::DateTime<chrono::Utc>,
|
||||
) -> ZonedTimeResult {
|
||||
let in_tz = now_utc.with_timezone(&tz);
|
||||
let date = in_tz.date_naive();
|
||||
let hour = in_tz.hour();
|
||||
let minute = in_tz.minute();
|
||||
let tz_abbr = in_tz.format("%Z").to_string();
|
||||
let (h12, is_pm) = crate::datetime::time_math::to_12h(hour);
|
||||
|
||||
ZonedTimeResult {
|
||||
hour12: h12,
|
||||
minute,
|
||||
is_pm,
|
||||
tz_abbr,
|
||||
date,
|
||||
date_changed: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a `ZonedTimeResult` for display.
|
||||
pub fn format_zoned_time(
|
||||
result: &ZonedTimeResult,
|
||||
date_format: crate::datetime::date_math::DateFormat,
|
||||
) -> String {
|
||||
let ampm = if result.is_pm { "PM" } else { "AM" };
|
||||
let time_str = format!("{}:{:02} {} {}", result.hour12, result.minute, ampm, result.tz_abbr);
|
||||
if result.date_changed {
|
||||
let date_str = crate::datetime::date_math::format_date(result.date, date_format);
|
||||
format!("{} ({})", time_str, date_str)
|
||||
} else {
|
||||
time_str
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of city aliases in the database.
|
||||
pub fn city_alias_count() -> usize {
|
||||
city_map().len()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal lookup tables
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn abbreviation_map() -> &'static HashMap<&'static str, Tz> {
|
||||
static MAP: OnceLock<HashMap<&'static str, Tz>> = OnceLock::new();
|
||||
MAP.get_or_init(|| {
|
||||
let mut m = HashMap::new();
|
||||
// North America
|
||||
m.insert("est", Tz::America__New_York);
|
||||
m.insert("edt", Tz::America__New_York);
|
||||
m.insert("cst", Tz::America__Chicago);
|
||||
m.insert("cdt", Tz::America__Chicago);
|
||||
m.insert("mst", Tz::America__Denver);
|
||||
m.insert("mdt", Tz::America__Denver);
|
||||
m.insert("pst", Tz::America__Los_Angeles);
|
||||
m.insert("pdt", Tz::America__Los_Angeles);
|
||||
m.insert("akst", Tz::America__Anchorage);
|
||||
m.insert("akdt", Tz::America__Anchorage);
|
||||
m.insert("hst", Tz::Pacific__Honolulu);
|
||||
m.insert("ast", Tz::America__Halifax);
|
||||
m.insert("adt", Tz::America__Halifax);
|
||||
m.insert("nst", Tz::America__St_Johns);
|
||||
m.insert("ndt", Tz::America__St_Johns);
|
||||
// Europe
|
||||
m.insert("gmt", Tz::Europe__London);
|
||||
m.insert("bst", Tz::Europe__London);
|
||||
m.insert("utc", Tz::UTC);
|
||||
m.insert("wet", Tz::Europe__Lisbon);
|
||||
m.insert("west", Tz::Europe__Lisbon);
|
||||
m.insert("cet", Tz::Europe__Paris);
|
||||
m.insert("cest", Tz::Europe__Paris);
|
||||
m.insert("eet", Tz::Europe__Bucharest);
|
||||
m.insert("eest", Tz::Europe__Bucharest);
|
||||
m.insert("msk", Tz::Europe__Moscow);
|
||||
// Asia
|
||||
m.insert("ist", Tz::Asia__Kolkata);
|
||||
m.insert("jst", Tz::Asia__Tokyo);
|
||||
m.insert("kst", Tz::Asia__Seoul);
|
||||
m.insert("hkt", Tz::Asia__Hong_Kong);
|
||||
m.insert("sgt", Tz::Asia__Singapore);
|
||||
m.insert("pht", Tz::Asia__Manila);
|
||||
m.insert("wib", Tz::Asia__Jakarta);
|
||||
m.insert("wit", Tz::Asia__Jayapura);
|
||||
m.insert("wita", Tz::Asia__Makassar);
|
||||
m.insert("ict", Tz::Asia__Bangkok);
|
||||
m.insert("bdt", Tz::Asia__Dhaka);
|
||||
m.insert("pkt", Tz::Asia__Karachi);
|
||||
m.insert("aft", Tz::Asia__Kabul);
|
||||
m.insert("irst", Tz::Asia__Tehran);
|
||||
m.insert("gst", Tz::Asia__Dubai);
|
||||
m.insert("trt", Tz::Europe__Istanbul);
|
||||
// Oceania
|
||||
m.insert("aest", Tz::Australia__Sydney);
|
||||
m.insert("aedt", Tz::Australia__Sydney);
|
||||
m.insert("acst", Tz::Australia__Adelaide);
|
||||
m.insert("acdt", Tz::Australia__Adelaide);
|
||||
m.insert("awst", Tz::Australia__Perth);
|
||||
m.insert("nzst", Tz::Pacific__Auckland);
|
||||
m.insert("nzdt", Tz::Pacific__Auckland);
|
||||
// South America
|
||||
m.insert("brt", Tz::America__Sao_Paulo);
|
||||
m.insert("art", Tz::America__Argentina__Buenos_Aires);
|
||||
m.insert("clt", Tz::America__Santiago);
|
||||
m.insert("pet", Tz::America__Lima);
|
||||
m.insert("cot", Tz::America__Bogota);
|
||||
m.insert("vet", Tz::America__Caracas);
|
||||
// Africa
|
||||
m.insert("cat", Tz::Africa__Maputo);
|
||||
m.insert("eat", Tz::Africa__Nairobi);
|
||||
m.insert("wat", Tz::Africa__Lagos);
|
||||
m.insert("sast", Tz::Africa__Johannesburg);
|
||||
m
|
||||
})
|
||||
}
|
||||
|
||||
fn city_map() -> &'static HashMap<&'static str, Tz> {
|
||||
static MAP: OnceLock<HashMap<&'static str, Tz>> = OnceLock::new();
|
||||
MAP.get_or_init(|| {
|
||||
let mut m = HashMap::new();
|
||||
|
||||
// ===== NORTH AMERICA =====
|
||||
m.insert("new york", Tz::America__New_York);
|
||||
m.insert("new york city", Tz::America__New_York);
|
||||
m.insert("nyc", Tz::America__New_York);
|
||||
m.insert("manhattan", Tz::America__New_York);
|
||||
m.insert("brooklyn", Tz::America__New_York);
|
||||
m.insert("queens", Tz::America__New_York);
|
||||
m.insert("bronx", Tz::America__New_York);
|
||||
m.insert("los angeles", Tz::America__Los_Angeles);
|
||||
m.insert("la", Tz::America__Los_Angeles);
|
||||
m.insert("hollywood", Tz::America__Los_Angeles);
|
||||
m.insert("chicago", Tz::America__Chicago);
|
||||
m.insert("houston", Tz::America__Chicago);
|
||||
m.insert("phoenix", Tz::America__Phoenix);
|
||||
m.insert("philadelphia", Tz::America__New_York);
|
||||
m.insert("san antonio", Tz::America__Chicago);
|
||||
m.insert("san diego", Tz::America__Los_Angeles);
|
||||
m.insert("dallas", Tz::America__Chicago);
|
||||
m.insert("san jose", Tz::America__Los_Angeles);
|
||||
m.insert("austin", Tz::America__Chicago);
|
||||
m.insert("jacksonville", Tz::America__New_York);
|
||||
m.insert("fort worth", Tz::America__Chicago);
|
||||
m.insert("columbus", Tz::America__New_York);
|
||||
m.insert("charlotte", Tz::America__New_York);
|
||||
m.insert("san francisco", Tz::America__Los_Angeles);
|
||||
m.insert("sf", Tz::America__Los_Angeles);
|
||||
m.insert("indianapolis", Tz::America__Indiana__Indianapolis);
|
||||
m.insert("seattle", Tz::America__Los_Angeles);
|
||||
m.insert("denver", Tz::America__Denver);
|
||||
m.insert("washington", Tz::America__New_York);
|
||||
m.insert("washington dc", Tz::America__New_York);
|
||||
m.insert("dc", Tz::America__New_York);
|
||||
m.insert("nashville", Tz::America__Chicago);
|
||||
m.insert("oklahoma city", Tz::America__Chicago);
|
||||
m.insert("el paso", Tz::America__Denver);
|
||||
m.insert("boston", Tz::America__New_York);
|
||||
m.insert("portland", Tz::America__Los_Angeles);
|
||||
m.insert("portland, or", Tz::America__Los_Angeles);
|
||||
m.insert("portland, oregon", Tz::America__Los_Angeles);
|
||||
m.insert("portland, me", Tz::America__New_York);
|
||||
m.insert("portland, maine", Tz::America__New_York);
|
||||
m.insert("las vegas", Tz::America__Los_Angeles);
|
||||
m.insert("vegas", Tz::America__Los_Angeles);
|
||||
m.insert("memphis", Tz::America__Chicago);
|
||||
m.insert("louisville", Tz::America__Kentucky__Louisville);
|
||||
m.insert("baltimore", Tz::America__New_York);
|
||||
m.insert("milwaukee", Tz::America__Chicago);
|
||||
m.insert("albuquerque", Tz::America__Denver);
|
||||
m.insert("tucson", Tz::America__Phoenix);
|
||||
m.insert("fresno", Tz::America__Los_Angeles);
|
||||
m.insert("sacramento", Tz::America__Los_Angeles);
|
||||
m.insert("mesa", Tz::America__Phoenix);
|
||||
m.insert("atlanta", Tz::America__New_York);
|
||||
m.insert("kansas city", Tz::America__Chicago);
|
||||
m.insert("colorado springs", Tz::America__Denver);
|
||||
m.insert("omaha", Tz::America__Chicago);
|
||||
m.insert("raleigh", Tz::America__New_York);
|
||||
m.insert("miami", Tz::America__New_York);
|
||||
m.insert("tampa", Tz::America__New_York);
|
||||
m.insert("orlando", Tz::America__New_York);
|
||||
m.insert("cleveland", Tz::America__New_York);
|
||||
m.insert("pittsburgh", Tz::America__New_York);
|
||||
m.insert("cincinnati", Tz::America__New_York);
|
||||
m.insert("minneapolis", Tz::America__Chicago);
|
||||
m.insert("st louis", Tz::America__Chicago);
|
||||
m.insert("saint louis", Tz::America__Chicago);
|
||||
m.insert("new orleans", Tz::America__Chicago);
|
||||
m.insert("detroit", Tz::America__Detroit);
|
||||
m.insert("salt lake city", Tz::America__Denver);
|
||||
m.insert("honolulu", Tz::Pacific__Honolulu);
|
||||
m.insert("hawaii", Tz::Pacific__Honolulu);
|
||||
m.insert("anchorage", Tz::America__Anchorage);
|
||||
m.insert("alaska", Tz::America__Anchorage);
|
||||
m.insert("boise", Tz::America__Boise);
|
||||
m.insert("richmond", Tz::America__New_York);
|
||||
m.insert("buffalo", Tz::America__New_York);
|
||||
|
||||
// Canada
|
||||
m.insert("toronto", Tz::America__Toronto);
|
||||
m.insert("vancouver", Tz::America__Vancouver);
|
||||
m.insert("montreal", Tz::America__Montreal);
|
||||
m.insert("calgary", Tz::America__Edmonton);
|
||||
m.insert("edmonton", Tz::America__Edmonton);
|
||||
m.insert("ottawa", Tz::America__Toronto);
|
||||
m.insert("winnipeg", Tz::America__Winnipeg);
|
||||
m.insert("halifax", Tz::America__Halifax);
|
||||
m.insert("regina", Tz::America__Regina);
|
||||
|
||||
// Mexico
|
||||
m.insert("mexico city", Tz::America__Mexico_City);
|
||||
m.insert("guadalajara", Tz::America__Mexico_City);
|
||||
m.insert("monterrey", Tz::America__Monterrey);
|
||||
m.insert("cancun", Tz::America__Cancun);
|
||||
m.insert("tijuana", Tz::America__Tijuana);
|
||||
|
||||
// ===== SOUTH AMERICA =====
|
||||
m.insert("sao paulo", Tz::America__Sao_Paulo);
|
||||
m.insert("rio de janeiro", Tz::America__Sao_Paulo);
|
||||
m.insert("rio", Tz::America__Sao_Paulo);
|
||||
m.insert("buenos aires", Tz::America__Argentina__Buenos_Aires);
|
||||
m.insert("santiago", Tz::America__Santiago);
|
||||
m.insert("lima", Tz::America__Lima);
|
||||
m.insert("bogota", Tz::America__Bogota);
|
||||
m.insert("caracas", Tz::America__Caracas);
|
||||
m.insert("montevideo", Tz::America__Montevideo);
|
||||
|
||||
// ===== EUROPE =====
|
||||
m.insert("london", Tz::Europe__London);
|
||||
m.insert("edinburgh", Tz::Europe__London);
|
||||
m.insert("manchester", Tz::Europe__London);
|
||||
m.insert("glasgow", Tz::Europe__London);
|
||||
m.insert("dublin", Tz::Europe__Dublin);
|
||||
m.insert("paris", Tz::Europe__Paris);
|
||||
m.insert("berlin", Tz::Europe__Berlin);
|
||||
m.insert("amsterdam", Tz::Europe__Amsterdam);
|
||||
m.insert("brussels", Tz::Europe__Brussels);
|
||||
m.insert("zurich", Tz::Europe__Zurich);
|
||||
m.insert("geneva", Tz::Europe__Zurich);
|
||||
m.insert("vienna", Tz::Europe__Vienna);
|
||||
m.insert("munich", Tz::Europe__Berlin);
|
||||
m.insert("frankfurt", Tz::Europe__Berlin);
|
||||
m.insert("rome", Tz::Europe__Rome);
|
||||
m.insert("milan", Tz::Europe__Rome);
|
||||
m.insert("madrid", Tz::Europe__Madrid);
|
||||
m.insert("barcelona", Tz::Europe__Madrid);
|
||||
m.insert("lisbon", Tz::Europe__Lisbon);
|
||||
m.insert("athens", Tz::Europe__Athens);
|
||||
m.insert("stockholm", Tz::Europe__Stockholm);
|
||||
m.insert("oslo", Tz::Europe__Oslo);
|
||||
m.insert("copenhagen", Tz::Europe__Copenhagen);
|
||||
m.insert("helsinki", Tz::Europe__Helsinki);
|
||||
m.insert("moscow", Tz::Europe__Moscow);
|
||||
m.insert("st petersburg", Tz::Europe__Moscow);
|
||||
m.insert("saint petersburg", Tz::Europe__Moscow);
|
||||
m.insert("warsaw", Tz::Europe__Warsaw);
|
||||
m.insert("prague", Tz::Europe__Prague);
|
||||
m.insert("budapest", Tz::Europe__Budapest);
|
||||
m.insert("bucharest", Tz::Europe__Bucharest);
|
||||
m.insert("istanbul", Tz::Europe__Istanbul);
|
||||
m.insert("kyiv", Tz::Europe__Kyiv);
|
||||
m.insert("kiev", Tz::Europe__Kyiv);
|
||||
|
||||
// ===== ASIA =====
|
||||
m.insert("tokyo", Tz::Asia__Tokyo);
|
||||
m.insert("osaka", Tz::Asia__Tokyo);
|
||||
m.insert("kyoto", Tz::Asia__Tokyo);
|
||||
m.insert("seoul", Tz::Asia__Seoul);
|
||||
m.insert("busan", Tz::Asia__Seoul);
|
||||
m.insert("beijing", Tz::Asia__Shanghai);
|
||||
m.insert("shanghai", Tz::Asia__Shanghai);
|
||||
m.insert("guangzhou", Tz::Asia__Shanghai);
|
||||
m.insert("shenzhen", Tz::Asia__Shanghai);
|
||||
m.insert("hong kong", Tz::Asia__Hong_Kong);
|
||||
m.insert("taipei", Tz::Asia__Taipei);
|
||||
m.insert("singapore", Tz::Asia__Singapore);
|
||||
m.insert("bangkok", Tz::Asia__Bangkok);
|
||||
m.insert("jakarta", Tz::Asia__Jakarta);
|
||||
m.insert("bali", Tz::Asia__Makassar);
|
||||
m.insert("kuala lumpur", Tz::Asia__Kuala_Lumpur);
|
||||
m.insert("manila", Tz::Asia__Manila);
|
||||
m.insert("ho chi minh city", Tz::Asia__Ho_Chi_Minh);
|
||||
m.insert("saigon", Tz::Asia__Ho_Chi_Minh);
|
||||
m.insert("hanoi", Tz::Asia__Ho_Chi_Minh);
|
||||
m.insert("mumbai", Tz::Asia__Kolkata);
|
||||
m.insert("delhi", Tz::Asia__Kolkata);
|
||||
m.insert("new delhi", Tz::Asia__Kolkata);
|
||||
m.insert("bangalore", Tz::Asia__Kolkata);
|
||||
m.insert("bengaluru", Tz::Asia__Kolkata);
|
||||
m.insert("chennai", Tz::Asia__Kolkata);
|
||||
m.insert("kolkata", Tz::Asia__Kolkata);
|
||||
m.insert("hyderabad", Tz::Asia__Kolkata);
|
||||
m.insert("karachi", Tz::Asia__Karachi);
|
||||
m.insert("lahore", Tz::Asia__Karachi);
|
||||
m.insert("islamabad", Tz::Asia__Karachi);
|
||||
m.insert("dhaka", Tz::Asia__Dhaka);
|
||||
m.insert("colombo", Tz::Asia__Colombo);
|
||||
m.insert("kathmandu", Tz::Asia__Kathmandu);
|
||||
m.insert("dubai", Tz::Asia__Dubai);
|
||||
m.insert("abu dhabi", Tz::Asia__Dubai);
|
||||
m.insert("doha", Tz::Asia__Qatar);
|
||||
m.insert("riyadh", Tz::Asia__Riyadh);
|
||||
m.insert("jeddah", Tz::Asia__Riyadh);
|
||||
m.insert("tehran", Tz::Asia__Tehran);
|
||||
m.insert("baghdad", Tz::Asia__Baghdad);
|
||||
m.insert("beirut", Tz::Asia__Beirut);
|
||||
m.insert("jerusalem", Tz::Asia__Jerusalem);
|
||||
m.insert("tel aviv", Tz::Asia__Jerusalem);
|
||||
m.insert("kabul", Tz::Asia__Kabul);
|
||||
|
||||
// ===== AFRICA =====
|
||||
m.insert("cairo", Tz::Africa__Cairo);
|
||||
m.insert("lagos", Tz::Africa__Lagos);
|
||||
m.insert("nairobi", Tz::Africa__Nairobi);
|
||||
m.insert("johannesburg", Tz::Africa__Johannesburg);
|
||||
m.insert("cape town", Tz::Africa__Johannesburg);
|
||||
m.insert("casablanca", Tz::Africa__Casablanca);
|
||||
m.insert("addis ababa", Tz::Africa__Addis_Ababa);
|
||||
m.insert("accra", Tz::Africa__Accra);
|
||||
m.insert("dakar", Tz::Africa__Dakar);
|
||||
|
||||
// ===== OCEANIA =====
|
||||
m.insert("sydney", Tz::Australia__Sydney);
|
||||
m.insert("melbourne", Tz::Australia__Melbourne);
|
||||
m.insert("brisbane", Tz::Australia__Brisbane);
|
||||
m.insert("perth", Tz::Australia__Perth);
|
||||
m.insert("adelaide", Tz::Australia__Adelaide);
|
||||
m.insert("canberra", Tz::Australia__Sydney);
|
||||
m.insert("darwin", Tz::Australia__Darwin);
|
||||
m.insert("auckland", Tz::Pacific__Auckland);
|
||||
m.insert("wellington", Tz::Pacific__Auckland);
|
||||
|
||||
// ===== COUNTRY ALIASES =====
|
||||
m.insert("japan", Tz::Asia__Tokyo);
|
||||
m.insert("korea", Tz::Asia__Seoul);
|
||||
m.insert("south korea", Tz::Asia__Seoul);
|
||||
m.insert("china", Tz::Asia__Shanghai);
|
||||
m.insert("india", Tz::Asia__Kolkata);
|
||||
m.insert("australia", Tz::Australia__Sydney);
|
||||
m.insert("brazil", Tz::America__Sao_Paulo);
|
||||
m.insert("germany", Tz::Europe__Berlin);
|
||||
m.insert("france", Tz::Europe__Paris);
|
||||
m.insert("spain", Tz::Europe__Madrid);
|
||||
m.insert("italy", Tz::Europe__Rome);
|
||||
m.insert("uk", Tz::Europe__London);
|
||||
m.insert("england", Tz::Europe__London);
|
||||
m.insert("ireland", Tz::Europe__Dublin);
|
||||
m.insert("russia", Tz::Europe__Moscow);
|
||||
m.insert("turkey", Tz::Europe__Istanbul);
|
||||
m.insert("egypt", Tz::Africa__Cairo);
|
||||
m.insert("south africa", Tz::Africa__Johannesburg);
|
||||
m.insert("nigeria", Tz::Africa__Lagos);
|
||||
m.insert("kenya", Tz::Africa__Nairobi);
|
||||
m.insert("thailand", Tz::Asia__Bangkok);
|
||||
m.insert("vietnam", Tz::Asia__Ho_Chi_Minh);
|
||||
m.insert("philippines", Tz::Asia__Manila);
|
||||
m.insert("indonesia", Tz::Asia__Jakarta);
|
||||
m.insert("malaysia", Tz::Asia__Kuala_Lumpur);
|
||||
m.insert("pakistan", Tz::Asia__Karachi);
|
||||
m.insert("bangladesh", Tz::Asia__Dhaka);
|
||||
m.insert("sri lanka", Tz::Asia__Colombo);
|
||||
m.insert("nepal", Tz::Asia__Kathmandu);
|
||||
m.insert("iran", Tz::Asia__Tehran);
|
||||
m.insert("iraq", Tz::Asia__Baghdad);
|
||||
m.insert("saudi arabia", Tz::Asia__Riyadh);
|
||||
m.insert("uae", Tz::Asia__Dubai);
|
||||
m.insert("qatar", Tz::Asia__Qatar);
|
||||
m.insert("israel", Tz::Asia__Jerusalem);
|
||||
m.insert("mexico", Tz::America__Mexico_City);
|
||||
m.insert("argentina", Tz::America__Argentina__Buenos_Aires);
|
||||
m.insert("colombia", Tz::America__Bogota);
|
||||
m.insert("peru", Tz::America__Lima);
|
||||
m.insert("chile", Tz::America__Santiago);
|
||||
m.insert("new zealand", Tz::Pacific__Auckland);
|
||||
m.insert("portugal", Tz::Europe__Lisbon);
|
||||
m.insert("netherlands", Tz::Europe__Amsterdam);
|
||||
m.insert("holland", Tz::Europe__Amsterdam);
|
||||
m.insert("belgium", Tz::Europe__Brussels);
|
||||
m.insert("switzerland", Tz::Europe__Zurich);
|
||||
m.insert("austria", Tz::Europe__Vienna);
|
||||
m.insert("poland", Tz::Europe__Warsaw);
|
||||
m.insert("czech republic", Tz::Europe__Prague);
|
||||
m.insert("czechia", Tz::Europe__Prague);
|
||||
m.insert("hungary", Tz::Europe__Budapest);
|
||||
m.insert("romania", Tz::Europe__Bucharest);
|
||||
m.insert("greece", Tz::Europe__Athens);
|
||||
m.insert("sweden", Tz::Europe__Stockholm);
|
||||
m.insert("norway", Tz::Europe__Oslo);
|
||||
m.insert("denmark", Tz::Europe__Copenhagen);
|
||||
m.insert("finland", Tz::Europe__Helsinki);
|
||||
m.insert("ukraine", Tz::Europe__Kyiv);
|
||||
m.insert("taiwan", Tz::Asia__Taipei);
|
||||
m.insert("morocco", Tz::Africa__Casablanca);
|
||||
|
||||
m
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::{TimeZone, Utc};
|
||||
|
||||
#[test]
|
||||
fn test_resolve_city_name() {
|
||||
assert_eq!(resolve_timezone("Tokyo"), Some(Tz::Asia__Tokyo));
|
||||
assert_eq!(resolve_timezone("London"), Some(Tz::Europe__London));
|
||||
assert_eq!(resolve_timezone("New York"), Some(Tz::America__New_York));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_abbreviation() {
|
||||
assert_eq!(resolve_timezone("EST"), Some(Tz::America__New_York));
|
||||
assert_eq!(resolve_timezone("PST"), Some(Tz::America__Los_Angeles));
|
||||
assert_eq!(resolve_timezone("CET"), Some(Tz::Europe__Paris));
|
||||
assert_eq!(resolve_timezone("JST"), Some(Tz::Asia__Tokyo));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_case_insensitive() {
|
||||
assert_eq!(resolve_timezone("tokyo"), Some(Tz::Asia__Tokyo));
|
||||
assert_eq!(resolve_timezone("TOKYO"), Some(Tz::Asia__Tokyo));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_disambiguation() {
|
||||
assert_eq!(resolve_timezone("Portland"), Some(Tz::America__Los_Angeles));
|
||||
assert_eq!(resolve_timezone("Portland, ME"), Some(Tz::America__New_York));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_country() {
|
||||
assert_eq!(resolve_timezone("Japan"), Some(Tz::Asia__Tokyo));
|
||||
assert_eq!(resolve_timezone("India"), Some(Tz::Asia__Kolkata));
|
||||
assert_eq!(resolve_timezone("UK"), Some(Tz::Europe__London));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_unknown() {
|
||||
assert_eq!(resolve_timezone("Narnia"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_tokyo_to_london_winter() {
|
||||
// March 17, 2026: London is still on GMT (DST starts Mar 29).
|
||||
// 3:00 PM JST = 06:00 UTC = 06:00 GMT
|
||||
let source_date = NaiveDate::from_ymd_opt(2026, 3, 17).unwrap();
|
||||
let result = convert_time(
|
||||
15,
|
||||
0,
|
||||
source_date,
|
||||
Tz::Asia__Tokyo,
|
||||
Tz::Europe__London,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(result.hour12, 6);
|
||||
assert_eq!(result.minute, 0);
|
||||
assert!(!result.is_pm);
|
||||
assert_eq!(result.tz_abbr, "GMT");
|
||||
assert!(!result.date_changed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_tokyo_to_london_summer() {
|
||||
// July 15, 2026: London is on BST (UTC+1).
|
||||
// 3:00 PM JST = 06:00 UTC = 07:00 BST
|
||||
let source_date = NaiveDate::from_ymd_opt(2026, 7, 15).unwrap();
|
||||
let result = convert_time(
|
||||
15,
|
||||
0,
|
||||
source_date,
|
||||
Tz::Asia__Tokyo,
|
||||
Tz::Europe__London,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(result.hour12, 7);
|
||||
assert_eq!(result.minute, 0);
|
||||
assert!(!result.is_pm);
|
||||
assert_eq!(result.tz_abbr, "BST");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_date_boundary_crossing() {
|
||||
// 11:00 PM EDT (New York) on March 17 = 03:00 UTC March 18 = 12:00 PM JST March 18
|
||||
let source_date = NaiveDate::from_ymd_opt(2026, 3, 17).unwrap();
|
||||
let result = convert_time(
|
||||
23,
|
||||
0,
|
||||
source_date,
|
||||
Tz::America__New_York,
|
||||
Tz::Asia__Tokyo,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(result.hour12, 12);
|
||||
assert!(result.is_pm);
|
||||
assert_eq!(result.tz_abbr, "JST");
|
||||
assert!(result.date_changed);
|
||||
assert_eq!(result.date, NaiveDate::from_ymd_opt(2026, 3, 18).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_current_time_in_timezone() {
|
||||
let now_utc = Utc.with_ymd_and_hms(2026, 3, 17, 14, 0, 0).unwrap();
|
||||
let result = current_time_in(Tz::America__New_York, now_utc);
|
||||
// March 17 2026 14:00 UTC, NY is EDT (UTC-4) = 10:00 AM
|
||||
assert_eq!(result.hour12, 10);
|
||||
assert_eq!(result.minute, 0);
|
||||
assert!(!result.is_pm);
|
||||
assert_eq!(result.tz_abbr, "EDT");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_zoned_time_no_date_change() {
|
||||
let result = ZonedTimeResult {
|
||||
hour12: 6,
|
||||
minute: 0,
|
||||
is_pm: false,
|
||||
tz_abbr: "GMT".to_string(),
|
||||
date: NaiveDate::from_ymd_opt(2026, 3, 17).unwrap(),
|
||||
date_changed: false,
|
||||
};
|
||||
assert_eq!(
|
||||
format_zoned_time(&result, crate::datetime::date_math::DateFormat::US),
|
||||
"6:00 AM GMT"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_zoned_time_with_date_change() {
|
||||
let result = ZonedTimeResult {
|
||||
hour12: 12,
|
||||
minute: 0,
|
||||
is_pm: true,
|
||||
tz_abbr: "JST".to_string(),
|
||||
date: NaiveDate::from_ymd_opt(2026, 3, 18).unwrap(),
|
||||
date_changed: true,
|
||||
};
|
||||
assert_eq!(
|
||||
format_zoned_time(&result, crate::datetime::date_math::DateFormat::US),
|
||||
"12:00 PM JST (March 18, 2026)"
|
||||
);
|
||||
assert_eq!(
|
||||
format_zoned_time(&result, crate::datetime::date_math::DateFormat::EU),
|
||||
"12:00 PM JST (18 March 2026)"
|
||||
);
|
||||
}
|
||||
}
|
||||
156
calcpad-engine/src/datetime/unix.rs
Normal file
156
calcpad-engine/src/datetime/unix.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
//! Unix timestamp conversions: seconds since epoch to/from datetime,
|
||||
//! with optional timezone support.
|
||||
|
||||
use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone};
|
||||
use chrono_tz::Tz;
|
||||
|
||||
/// The result of converting a unix timestamp.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct UnixConversion {
|
||||
/// The human-readable date.
|
||||
pub date: NaiveDate,
|
||||
/// The human-readable time.
|
||||
pub time: NaiveTime,
|
||||
/// Whether the result is in UTC or a named timezone.
|
||||
pub tz_label: String,
|
||||
}
|
||||
|
||||
/// Convert a Unix timestamp (seconds since 1970-01-01 00:00:00 UTC) to a
|
||||
/// human-readable date/time in UTC.
|
||||
pub fn from_timestamp_utc(ts: i64) -> Option<UnixConversion> {
|
||||
let dt = DateTime::from_timestamp(ts, 0)?;
|
||||
Some(UnixConversion {
|
||||
date: dt.date_naive(),
|
||||
time: dt.time(),
|
||||
tz_label: "UTC".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert a Unix timestamp to a date/time in a specific timezone.
|
||||
pub fn from_timestamp_in_tz(ts: i64, tz: Tz) -> Option<UnixConversion> {
|
||||
let utc_dt = DateTime::from_timestamp(ts, 0)?;
|
||||
let local_dt = utc_dt.with_timezone(&tz);
|
||||
let tz_label = local_dt.format("%Z").to_string();
|
||||
Some(UnixConversion {
|
||||
date: local_dt.date_naive(),
|
||||
time: local_dt.time(),
|
||||
tz_label,
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert a UTC date/time to a Unix timestamp.
|
||||
pub fn to_timestamp_utc(date: NaiveDate, time: NaiveTime) -> Option<i64> {
|
||||
let naive = NaiveDateTime::new(date, time);
|
||||
Some(naive.and_utc().timestamp())
|
||||
}
|
||||
|
||||
/// Convert a local date/time in a specific timezone to a Unix timestamp.
|
||||
pub fn to_timestamp_in_tz(date: NaiveDate, time: NaiveTime, tz: Tz) -> Option<i64> {
|
||||
let naive = NaiveDateTime::new(date, time);
|
||||
let local = tz.from_local_datetime(&naive).single()?;
|
||||
Some(local.timestamp())
|
||||
}
|
||||
|
||||
/// Format a `UnixConversion` for display.
|
||||
pub fn format_unix_conversion(conv: &UnixConversion) -> String {
|
||||
let date_str = conv.date.format("%Y-%m-%d").to_string();
|
||||
let time_str = conv.time.format("%H:%M:%S").to_string();
|
||||
format!("{} {} {}", date_str, time_str, conv.tz_label)
|
||||
}
|
||||
|
||||
/// Format a timestamp as a simple integer string.
|
||||
pub fn format_timestamp(ts: i64) -> String {
|
||||
format!("{}", ts)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn d(y: i32, m: u32, day: u32) -> NaiveDate {
|
||||
NaiveDate::from_ymd_opt(y, m, day).unwrap()
|
||||
}
|
||||
|
||||
fn t(h: u32, m: u32, s: u32) -> NaiveTime {
|
||||
NaiveTime::from_hms_opt(h, m, s).unwrap()
|
||||
}
|
||||
|
||||
// -- from_timestamp_utc --
|
||||
|
||||
#[test]
|
||||
fn test_epoch_zero() {
|
||||
let result = from_timestamp_utc(0).unwrap();
|
||||
assert_eq!(result.date, d(1970, 1, 1));
|
||||
assert_eq!(result.time, t(0, 0, 0));
|
||||
assert_eq!(result.tz_label, "UTC");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_known_timestamp() {
|
||||
// 1700000000 = 2023-11-14 22:13:20 UTC
|
||||
let result = from_timestamp_utc(1_700_000_000).unwrap();
|
||||
assert_eq!(result.date, d(2023, 11, 14));
|
||||
assert_eq!(result.time, t(22, 13, 20));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_negative_timestamp() {
|
||||
// -86400 = 1969-12-31 00:00:00 UTC
|
||||
let result = from_timestamp_utc(-86400).unwrap();
|
||||
assert_eq!(result.date, d(1969, 12, 31));
|
||||
assert_eq!(result.time, t(0, 0, 0));
|
||||
}
|
||||
|
||||
// -- from_timestamp_in_tz --
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_in_tokyo() {
|
||||
// 0 = 1970-01-01 09:00:00 JST
|
||||
let result = from_timestamp_in_tz(0, Tz::Asia__Tokyo).unwrap();
|
||||
assert_eq!(result.date, d(1970, 1, 1));
|
||||
assert_eq!(result.time, t(9, 0, 0));
|
||||
assert_eq!(result.tz_label, "JST");
|
||||
}
|
||||
|
||||
// -- to_timestamp_utc --
|
||||
|
||||
#[test]
|
||||
fn test_to_timestamp_epoch() {
|
||||
let ts = to_timestamp_utc(d(1970, 1, 1), t(0, 0, 0)).unwrap();
|
||||
assert_eq!(ts, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roundtrip_utc() {
|
||||
let original_ts: i64 = 1_700_000_000;
|
||||
let conv = from_timestamp_utc(original_ts).unwrap();
|
||||
let ts = to_timestamp_utc(conv.date, conv.time).unwrap();
|
||||
assert_eq!(ts, original_ts);
|
||||
}
|
||||
|
||||
// -- to_timestamp_in_tz --
|
||||
|
||||
#[test]
|
||||
fn test_to_timestamp_tokyo() {
|
||||
// 1970-01-01 09:00:00 JST should be epoch 0
|
||||
let ts = to_timestamp_in_tz(d(1970, 1, 1), t(9, 0, 0), Tz::Asia__Tokyo).unwrap();
|
||||
assert_eq!(ts, 0);
|
||||
}
|
||||
|
||||
// -- format --
|
||||
|
||||
#[test]
|
||||
fn test_format_unix_conversion() {
|
||||
let conv = UnixConversion {
|
||||
date: d(2023, 11, 14),
|
||||
time: t(22, 13, 20),
|
||||
tz_label: "UTC".to_string(),
|
||||
};
|
||||
assert_eq!(format_unix_conversion(&conv), "2023-11-14 22:13:20 UTC");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_timestamp() {
|
||||
assert_eq!(format_timestamp(1_700_000_000), "1700000000");
|
||||
}
|
||||
}
|
||||
241
calcpad-engine/src/functions/combinatorics.rs
Normal file
241
calcpad-engine/src/functions/combinatorics.rs
Normal file
@@ -0,0 +1,241 @@
|
||||
//! Factorial and combinatorics: factorial, nPr, nCr.
|
||||
//!
|
||||
//! Uses arbitrary-precision internally via u128 for intermediate products
|
||||
//! to handle factorials up to ~34. For truly large factorials (100!), callers
|
||||
//! should use `factorial_bigint` which returns a string. The f64-based
|
||||
//! `factorial` registered here will overflow gracefully to `f64::INFINITY`
|
||||
//! for n > 170 (standard IEEE 754 limit).
|
||||
|
||||
use super::{FunctionError, FunctionRegistry};
|
||||
|
||||
/// Compute n! as f64. Returns +Infinity when n > 170.
|
||||
fn factorial_f64(n: f64) -> Result<f64, FunctionError> {
|
||||
if n < 0.0 || n.fract() != 0.0 {
|
||||
return Err(FunctionError::new(
|
||||
"Factorial is only defined for non-negative integers",
|
||||
));
|
||||
}
|
||||
let n = n as u64;
|
||||
let mut result: f64 = 1.0;
|
||||
for i in 2..=n {
|
||||
result *= i as f64;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn factorial_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
factorial_f64(args[0])
|
||||
}
|
||||
|
||||
/// Compute nPr = n! / (n-k)! as f64.
|
||||
fn permutation_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
let n = args[0];
|
||||
let k = args[1];
|
||||
|
||||
if n.fract() != 0.0 || k.fract() != 0.0 {
|
||||
return Err(FunctionError::new("nPr requires integer arguments"));
|
||||
}
|
||||
if n < 0.0 || k < 0.0 {
|
||||
return Err(FunctionError::new("nPr requires non-negative arguments"));
|
||||
}
|
||||
|
||||
let n = n as u64;
|
||||
let k = k as u64;
|
||||
|
||||
if k > n {
|
||||
return Ok(0.0);
|
||||
}
|
||||
|
||||
let mut result: f64 = 1.0;
|
||||
for i in 0..k {
|
||||
result *= (n - i) as f64;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Compute nCr = n! / (k! * (n-k)!) as f64.
|
||||
fn combination_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
let n = args[0];
|
||||
let k = args[1];
|
||||
|
||||
if n.fract() != 0.0 || k.fract() != 0.0 {
|
||||
return Err(FunctionError::new("nCr requires integer arguments"));
|
||||
}
|
||||
if n < 0.0 || k < 0.0 {
|
||||
return Err(FunctionError::new("nCr requires non-negative arguments"));
|
||||
}
|
||||
|
||||
let n = n as u64;
|
||||
let mut k = k as u64;
|
||||
|
||||
if k > n {
|
||||
return Ok(0.0);
|
||||
}
|
||||
|
||||
// Optimise: C(n,k) == C(n, n-k)
|
||||
if k > n - k {
|
||||
k = n - k;
|
||||
}
|
||||
|
||||
let mut result: f64 = 1.0;
|
||||
for i in 0..k {
|
||||
result *= (n - i) as f64;
|
||||
result /= (i + 1) as f64;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Register combinatorics functions.
|
||||
pub fn register(reg: &mut FunctionRegistry) {
|
||||
reg.register_fixed("factorial", 1, factorial_fn);
|
||||
reg.register_fixed("nPr", 2, permutation_fn);
|
||||
reg.register_fixed("nCr", 2, combination_fn);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn reg() -> FunctionRegistry {
|
||||
FunctionRegistry::new()
|
||||
}
|
||||
|
||||
// --- factorial ---
|
||||
|
||||
#[test]
|
||||
fn factorial_zero_is_one() {
|
||||
let v = reg().call("factorial", &[0.0]).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factorial_one_is_one() {
|
||||
let v = reg().call("factorial", &[1.0]).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factorial_five_is_120() {
|
||||
let v = reg().call("factorial", &[5.0]).unwrap();
|
||||
assert!((v - 120.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factorial_ten_is_3628800() {
|
||||
let v = reg().call("factorial", &[10.0]).unwrap();
|
||||
assert!((v - 3_628_800.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factorial_20() {
|
||||
let v = reg().call("factorial", &[20.0]).unwrap();
|
||||
// 20! = 2432902008176640000
|
||||
assert!((v - 2_432_902_008_176_640_000.0).abs() < 1e3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factorial_negative_error() {
|
||||
let err = reg().call("factorial", &[-3.0]).unwrap_err();
|
||||
assert!(err.message.contains("non-negative integers"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factorial_non_integer_error() {
|
||||
let err = reg().call("factorial", &[3.5]).unwrap_err();
|
||||
assert!(err.message.contains("non-negative integers"));
|
||||
}
|
||||
|
||||
// --- nPr ---
|
||||
|
||||
#[test]
|
||||
fn npr_10_3_is_720() {
|
||||
let v = reg().call("nPr", &[10.0, 3.0]).unwrap();
|
||||
assert!((v - 720.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn npr_5_5_is_120() {
|
||||
let v = reg().call("nPr", &[5.0, 5.0]).unwrap();
|
||||
assert!((v - 120.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn npr_5_0_is_1() {
|
||||
let v = reg().call("nPr", &[5.0, 0.0]).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn npr_0_0_is_1() {
|
||||
let v = reg().call("nPr", &[0.0, 0.0]).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn npr_k_greater_than_n_is_zero() {
|
||||
let v = reg().call("nPr", &[5.0, 7.0]).unwrap();
|
||||
assert!((v - 0.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn npr_negative_error() {
|
||||
let err = reg().call("nPr", &[-1.0, 2.0]).unwrap_err();
|
||||
assert!(err.message.contains("non-negative"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn npr_non_integer_error() {
|
||||
let err = reg().call("nPr", &[5.5, 2.0]).unwrap_err();
|
||||
assert!(err.message.contains("integer"));
|
||||
}
|
||||
|
||||
// --- nCr ---
|
||||
|
||||
#[test]
|
||||
fn ncr_10_3_is_120() {
|
||||
let v = reg().call("nCr", &[10.0, 3.0]).unwrap();
|
||||
assert!((v - 120.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ncr_5_2_is_10() {
|
||||
let v = reg().call("nCr", &[5.0, 2.0]).unwrap();
|
||||
assert!((v - 10.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ncr_5_0_is_1() {
|
||||
let v = reg().call("nCr", &[5.0, 0.0]).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ncr_5_5_is_1() {
|
||||
let v = reg().call("nCr", &[5.0, 5.0]).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ncr_k_greater_than_n_is_zero() {
|
||||
let v = reg().call("nCr", &[5.0, 7.0]).unwrap();
|
||||
assert!((v - 0.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ncr_0_0_is_1() {
|
||||
let v = reg().call("nCr", &[0.0, 0.0]).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ncr_negative_error() {
|
||||
let err = reg().call("nCr", &[-1.0, 2.0]).unwrap_err();
|
||||
assert!(err.message.contains("non-negative"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ncr_non_integer_error() {
|
||||
let err = reg().call("nCr", &[5.5, 2.0]).unwrap_err();
|
||||
assert!(err.message.contains("integer"));
|
||||
}
|
||||
}
|
||||
184
calcpad-engine/src/functions/financial.rs
Normal file
184
calcpad-engine/src/functions/financial.rs
Normal file
@@ -0,0 +1,184 @@
|
||||
//! Financial functions: compound interest and mortgage payment.
|
||||
//!
|
||||
//! ## compound_interest(principal, rate, periods)
|
||||
//!
|
||||
//! Returns `principal * (1 + rate)^periods`.
|
||||
//! - `principal` — initial investment / loan amount
|
||||
//! - `rate` — interest rate per period (e.g. 0.05 for 5%)
|
||||
//! - `periods` — number of compounding periods
|
||||
//!
|
||||
//! ## mortgage_payment(principal, annual_rate, years)
|
||||
//!
|
||||
//! Returns the monthly payment for a fixed-rate mortgage using the standard
|
||||
//! amortization formula:
|
||||
//!
|
||||
//! M = P * [r(1+r)^n] / [(1+r)^n - 1]
|
||||
//!
|
||||
//! where `r` = `annual_rate / 12` and `n` = `years * 12`.
|
||||
//!
|
||||
//! If the rate is 0, returns `principal / (years * 12)`.
|
||||
|
||||
use super::{FunctionError, FunctionRegistry};
|
||||
|
||||
/// compound_interest(principal, rate, periods)
|
||||
fn compound_interest_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
let principal = args[0];
|
||||
let rate = args[1];
|
||||
let periods = args[2];
|
||||
|
||||
if rate < -1.0 {
|
||||
return Err(FunctionError::new(
|
||||
"Interest rate must be >= -1 (i.e. at most a 100% loss per period)",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(principal * (1.0 + rate).powf(periods))
|
||||
}
|
||||
|
||||
/// mortgage_payment(principal, annual_rate, years)
|
||||
fn mortgage_payment_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
let principal = args[0];
|
||||
let annual_rate = args[1];
|
||||
let years = args[2];
|
||||
|
||||
if principal < 0.0 {
|
||||
return Err(FunctionError::new("Principal must be non-negative"));
|
||||
}
|
||||
if annual_rate < 0.0 {
|
||||
return Err(FunctionError::new("Annual rate must be non-negative"));
|
||||
}
|
||||
if years <= 0.0 {
|
||||
return Err(FunctionError::new("Loan term must be positive"));
|
||||
}
|
||||
|
||||
let n = years * 12.0; // total monthly payments
|
||||
|
||||
if annual_rate == 0.0 {
|
||||
// No interest — just divide evenly.
|
||||
return Ok(principal / n);
|
||||
}
|
||||
|
||||
let r = annual_rate / 12.0; // monthly rate
|
||||
let factor = (1.0 + r).powf(n);
|
||||
let payment = principal * (r * factor) / (factor - 1.0);
|
||||
Ok(payment)
|
||||
}
|
||||
|
||||
/// Register financial functions.
|
||||
pub fn register(reg: &mut FunctionRegistry) {
|
||||
reg.register_fixed("compound_interest", 3, compound_interest_fn);
|
||||
reg.register_fixed("mortgage_payment", 3, mortgage_payment_fn);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn reg() -> FunctionRegistry {
|
||||
FunctionRegistry::new()
|
||||
}
|
||||
|
||||
// --- compound_interest ---
|
||||
|
||||
#[test]
|
||||
fn compound_interest_basic() {
|
||||
// $1000 at 5% for 10 years => 1000 * 1.05^10 = 1628.89...
|
||||
let v = reg()
|
||||
.call("compound_interest", &[1000.0, 0.05, 10.0])
|
||||
.unwrap();
|
||||
assert!((v - 1628.894627).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compound_interest_zero_rate() {
|
||||
let v = reg()
|
||||
.call("compound_interest", &[1000.0, 0.0, 10.0])
|
||||
.unwrap();
|
||||
assert!((v - 1000.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compound_interest_zero_periods() {
|
||||
let v = reg()
|
||||
.call("compound_interest", &[1000.0, 0.05, 0.0])
|
||||
.unwrap();
|
||||
assert!((v - 1000.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compound_interest_one_period() {
|
||||
let v = reg()
|
||||
.call("compound_interest", &[1000.0, 0.1, 1.0])
|
||||
.unwrap();
|
||||
assert!((v - 1100.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compound_interest_negative_rate_too_low() {
|
||||
let err = reg()
|
||||
.call("compound_interest", &[1000.0, -1.5, 1.0])
|
||||
.unwrap_err();
|
||||
assert!(err.message.contains("rate"));
|
||||
}
|
||||
|
||||
// --- mortgage_payment ---
|
||||
|
||||
#[test]
|
||||
fn mortgage_payment_standard() {
|
||||
// $200,000 at 6% annual for 30 years => ~$1199.10/month
|
||||
let v = reg()
|
||||
.call("mortgage_payment", &[200_000.0, 0.06, 30.0])
|
||||
.unwrap();
|
||||
assert!((v - 1199.10).abs() < 0.02);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mortgage_payment_zero_rate() {
|
||||
// $120,000 at 0% for 10 years => $1000/month
|
||||
let v = reg()
|
||||
.call("mortgage_payment", &[120_000.0, 0.0, 10.0])
|
||||
.unwrap();
|
||||
assert!((v - 1000.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mortgage_payment_short_term() {
|
||||
// $12,000 at 12% annual for 1 year => ~$1066.19/month
|
||||
let v = reg()
|
||||
.call("mortgage_payment", &[12_000.0, 0.12, 1.0])
|
||||
.unwrap();
|
||||
assert!((v - 1066.19).abs() < 0.02);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mortgage_payment_negative_principal_error() {
|
||||
let err = reg()
|
||||
.call("mortgage_payment", &[-1000.0, 0.05, 10.0])
|
||||
.unwrap_err();
|
||||
assert!(err.message.contains("Principal"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mortgage_payment_negative_rate_error() {
|
||||
let err = reg()
|
||||
.call("mortgage_payment", &[1000.0, -0.05, 10.0])
|
||||
.unwrap_err();
|
||||
assert!(err.message.contains("rate"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mortgage_payment_zero_years_error() {
|
||||
let err = reg()
|
||||
.call("mortgage_payment", &[1000.0, 0.05, 0.0])
|
||||
.unwrap_err();
|
||||
assert!(err.message.contains("term"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mortgage_payment_arity_error() {
|
||||
let err = reg()
|
||||
.call("mortgage_payment", &[1000.0, 0.05])
|
||||
.unwrap_err();
|
||||
assert!(err.message.contains("expects 3 argument"));
|
||||
}
|
||||
}
|
||||
223
calcpad-engine/src/functions/list_ops.rs
Normal file
223
calcpad-engine/src/functions/list_ops.rs
Normal file
@@ -0,0 +1,223 @@
|
||||
//! Variadic list operations: min, max, gcd, lcm.
|
||||
//!
|
||||
//! All accept 1 or more arguments. `gcd` and `lcm` require integer arguments.
|
||||
|
||||
use super::{FunctionError, FunctionRegistry};
|
||||
|
||||
fn min_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
// We already know args.len() >= 1 from the variadic guard.
|
||||
let mut m = args[0];
|
||||
for &v in &args[1..] {
|
||||
if v < m {
|
||||
m = v;
|
||||
}
|
||||
}
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
fn max_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
let mut m = args[0];
|
||||
for &v in &args[1..] {
|
||||
if v > m {
|
||||
m = v;
|
||||
}
|
||||
}
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
/// GCD of two non-negative integers using Euclidean algorithm.
|
||||
fn gcd_pair(mut a: i64, mut b: i64) -> i64 {
|
||||
a = a.abs();
|
||||
b = b.abs();
|
||||
while b != 0 {
|
||||
let t = b;
|
||||
b = a % b;
|
||||
a = t;
|
||||
}
|
||||
a
|
||||
}
|
||||
|
||||
/// LCM of two non-negative integers.
|
||||
fn lcm_pair(a: i64, b: i64) -> i64 {
|
||||
if a == 0 || b == 0 {
|
||||
return 0;
|
||||
}
|
||||
(a.abs() / gcd_pair(a, b)) * b.abs()
|
||||
}
|
||||
|
||||
fn gcd_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
for &v in args {
|
||||
if v.fract() != 0.0 {
|
||||
return Err(FunctionError::new("gcd requires integer arguments"));
|
||||
}
|
||||
}
|
||||
let mut result = args[0] as i64;
|
||||
for &v in &args[1..] {
|
||||
result = gcd_pair(result, v as i64);
|
||||
}
|
||||
Ok(result as f64)
|
||||
}
|
||||
|
||||
fn lcm_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
for &v in args {
|
||||
if v.fract() != 0.0 {
|
||||
return Err(FunctionError::new("lcm requires integer arguments"));
|
||||
}
|
||||
}
|
||||
let mut result = args[0] as i64;
|
||||
for &v in &args[1..] {
|
||||
result = lcm_pair(result, v as i64);
|
||||
}
|
||||
Ok(result as f64)
|
||||
}
|
||||
|
||||
/// Register list-operation functions.
|
||||
pub fn register(reg: &mut FunctionRegistry) {
|
||||
reg.register_variadic("min", 1, min_fn);
|
||||
reg.register_variadic("max", 1, max_fn);
|
||||
reg.register_variadic("gcd", 1, gcd_fn);
|
||||
reg.register_variadic("lcm", 1, lcm_fn);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn reg() -> FunctionRegistry {
|
||||
FunctionRegistry::new()
|
||||
}
|
||||
|
||||
// --- min ---
|
||||
|
||||
#[test]
|
||||
fn min_single() {
|
||||
let v = reg().call("min", &[42.0]).unwrap();
|
||||
assert!((v - 42.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn min_two() {
|
||||
let v = reg().call("min", &[3.0, 7.0]).unwrap();
|
||||
assert!((v - 3.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn min_many() {
|
||||
let v = reg().call("min", &[10.0, 3.0, 7.0, 1.0, 5.0]).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn min_negative() {
|
||||
let v = reg().call("min", &[-5.0, -2.0, -10.0]).unwrap();
|
||||
assert!((v - (-10.0)).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn min_no_args_error() {
|
||||
let err = reg().call("min", &[]).unwrap_err();
|
||||
assert!(err.message.contains("at least 1"));
|
||||
}
|
||||
|
||||
// --- max ---
|
||||
|
||||
#[test]
|
||||
fn max_single() {
|
||||
let v = reg().call("max", &[42.0]).unwrap();
|
||||
assert!((v - 42.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn max_two() {
|
||||
let v = reg().call("max", &[3.0, 7.0]).unwrap();
|
||||
assert!((v - 7.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn max_many() {
|
||||
let v = reg().call("max", &[10.0, 3.0, 7.0, 1.0, 50.0]).unwrap();
|
||||
assert!((v - 50.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn max_no_args_error() {
|
||||
let err = reg().call("max", &[]).unwrap_err();
|
||||
assert!(err.message.contains("at least 1"));
|
||||
}
|
||||
|
||||
// --- gcd ---
|
||||
|
||||
#[test]
|
||||
fn gcd_two_numbers() {
|
||||
let v = reg().call("gcd", &[12.0, 8.0]).unwrap();
|
||||
assert!((v - 4.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gcd_three_numbers() {
|
||||
let v = reg().call("gcd", &[12.0, 8.0, 6.0]).unwrap();
|
||||
assert!((v - 2.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gcd_coprime() {
|
||||
let v = reg().call("gcd", &[7.0, 13.0]).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gcd_with_zero() {
|
||||
let v = reg().call("gcd", &[0.0, 5.0]).unwrap();
|
||||
assert!((v - 5.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gcd_single() {
|
||||
let v = reg().call("gcd", &[42.0]).unwrap();
|
||||
assert!((v - 42.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gcd_non_integer_error() {
|
||||
let err = reg().call("gcd", &[3.5, 2.0]).unwrap_err();
|
||||
assert!(err.message.contains("integer"));
|
||||
}
|
||||
|
||||
// --- lcm ---
|
||||
|
||||
#[test]
|
||||
fn lcm_two_numbers() {
|
||||
let v = reg().call("lcm", &[4.0, 6.0]).unwrap();
|
||||
assert!((v - 12.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lcm_three_numbers() {
|
||||
let v = reg().call("lcm", &[4.0, 6.0, 10.0]).unwrap();
|
||||
assert!((v - 60.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lcm_with_zero() {
|
||||
let v = reg().call("lcm", &[0.0, 5.0]).unwrap();
|
||||
assert!((v - 0.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lcm_single() {
|
||||
let v = reg().call("lcm", &[42.0]).unwrap();
|
||||
assert!((v - 42.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lcm_coprime() {
|
||||
let v = reg().call("lcm", &[7.0, 13.0]).unwrap();
|
||||
assert!((v - 91.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lcm_non_integer_error() {
|
||||
let err = reg().call("lcm", &[3.5, 2.0]).unwrap_err();
|
||||
assert!(err.message.contains("integer"));
|
||||
}
|
||||
}
|
||||
249
calcpad-engine/src/functions/logarithmic.rs
Normal file
249
calcpad-engine/src/functions/logarithmic.rs
Normal file
@@ -0,0 +1,249 @@
|
||||
//! Logarithmic, exponential, and root functions.
|
||||
//!
|
||||
//! - `ln` — natural logarithm (base e)
|
||||
//! - `log` — common logarithm (base 10)
|
||||
//! - `log2` — binary logarithm (base 2)
|
||||
//! - `exp` — e raised to a power
|
||||
//! - `pow` — base raised to an exponent (2 args)
|
||||
//! - `sqrt` — square root
|
||||
//! - `cbrt` — cube root
|
||||
|
||||
use super::{FunctionError, FunctionRegistry};
|
||||
|
||||
fn ln_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
let x = args[0];
|
||||
if x <= 0.0 {
|
||||
return Err(FunctionError::new(
|
||||
"Argument out of domain for ln (must be positive)",
|
||||
));
|
||||
}
|
||||
Ok(x.ln())
|
||||
}
|
||||
|
||||
fn log_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
let x = args[0];
|
||||
if x <= 0.0 {
|
||||
return Err(FunctionError::new(
|
||||
"Argument out of domain for log (must be positive)",
|
||||
));
|
||||
}
|
||||
Ok(x.log10())
|
||||
}
|
||||
|
||||
fn log2_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
let x = args[0];
|
||||
if x <= 0.0 {
|
||||
return Err(FunctionError::new(
|
||||
"Argument out of domain for log2 (must be positive)",
|
||||
));
|
||||
}
|
||||
Ok(x.log2())
|
||||
}
|
||||
|
||||
fn exp_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
Ok(args[0].exp())
|
||||
}
|
||||
|
||||
fn pow_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
Ok(args[0].powf(args[1]))
|
||||
}
|
||||
|
||||
fn sqrt_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
let x = args[0];
|
||||
if x < 0.0 {
|
||||
return Err(FunctionError::new(
|
||||
"Argument out of domain for sqrt (must be non-negative)",
|
||||
));
|
||||
}
|
||||
Ok(x.sqrt())
|
||||
}
|
||||
|
||||
fn cbrt_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
Ok(args[0].cbrt())
|
||||
}
|
||||
|
||||
/// Register all logarithmic/exponential/root functions.
|
||||
pub fn register(reg: &mut FunctionRegistry) {
|
||||
reg.register_fixed("ln", 1, ln_fn);
|
||||
reg.register_fixed("log", 1, log_fn);
|
||||
reg.register_fixed("log2", 1, log2_fn);
|
||||
reg.register_fixed("exp", 1, exp_fn);
|
||||
reg.register_fixed("pow", 2, pow_fn);
|
||||
reg.register_fixed("sqrt", 1, sqrt_fn);
|
||||
reg.register_fixed("cbrt", 1, cbrt_fn);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn reg() -> FunctionRegistry {
|
||||
FunctionRegistry::new()
|
||||
}
|
||||
|
||||
// --- ln ---
|
||||
|
||||
#[test]
|
||||
fn ln_one_is_zero() {
|
||||
let v = reg().call("ln", &[1.0]).unwrap();
|
||||
assert!(v.abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ln_e_is_one() {
|
||||
let v = reg().call("ln", &[std::f64::consts::E]).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ln_zero_domain_error() {
|
||||
let err = reg().call("ln", &[0.0]).unwrap_err();
|
||||
assert!(err.message.contains("out of domain"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ln_negative_domain_error() {
|
||||
let err = reg().call("ln", &[-1.0]).unwrap_err();
|
||||
assert!(err.message.contains("out of domain"));
|
||||
}
|
||||
|
||||
// --- log (base 10) ---
|
||||
|
||||
#[test]
|
||||
fn log_100_is_2() {
|
||||
let v = reg().call("log", &[100.0]).unwrap();
|
||||
assert!((v - 2.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn log_1000_is_3() {
|
||||
let v = reg().call("log", &[1000.0]).unwrap();
|
||||
assert!((v - 3.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn log_negative_domain_error() {
|
||||
let err = reg().call("log", &[-1.0]).unwrap_err();
|
||||
assert!(err.message.contains("out of domain"));
|
||||
}
|
||||
|
||||
// --- log2 ---
|
||||
|
||||
#[test]
|
||||
fn log2_256_is_8() {
|
||||
let v = reg().call("log2", &[256.0]).unwrap();
|
||||
assert!((v - 8.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn log2_one_is_zero() {
|
||||
let v = reg().call("log2", &[1.0]).unwrap();
|
||||
assert!(v.abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn log2_negative_domain_error() {
|
||||
let err = reg().call("log2", &[-5.0]).unwrap_err();
|
||||
assert!(err.message.contains("out of domain"));
|
||||
}
|
||||
|
||||
// --- exp ---
|
||||
|
||||
#[test]
|
||||
fn exp_zero_is_one() {
|
||||
let v = reg().call("exp", &[0.0]).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exp_one_is_e() {
|
||||
let v = reg().call("exp", &[1.0]).unwrap();
|
||||
assert!((v - std::f64::consts::E).abs() < 1e-10);
|
||||
}
|
||||
|
||||
// --- pow ---
|
||||
|
||||
#[test]
|
||||
fn pow_2_10_is_1024() {
|
||||
let v = reg().call("pow", &[2.0, 10.0]).unwrap();
|
||||
assert!((v - 1024.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pow_3_0_is_1() {
|
||||
let v = reg().call("pow", &[3.0, 0.0]).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
// --- sqrt ---
|
||||
|
||||
#[test]
|
||||
fn sqrt_144_is_12() {
|
||||
let v = reg().call("sqrt", &[144.0]).unwrap();
|
||||
assert!((v - 12.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sqrt_2_approx() {
|
||||
let v = reg().call("sqrt", &[2.0]).unwrap();
|
||||
assert!((v - std::f64::consts::SQRT_2).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sqrt_zero_is_zero() {
|
||||
let v = reg().call("sqrt", &[0.0]).unwrap();
|
||||
assert!(v.abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sqrt_negative_domain_error() {
|
||||
let err = reg().call("sqrt", &[-4.0]).unwrap_err();
|
||||
assert!(err.message.contains("out of domain"));
|
||||
}
|
||||
|
||||
// --- cbrt ---
|
||||
|
||||
#[test]
|
||||
fn cbrt_27_is_3() {
|
||||
let v = reg().call("cbrt", &[27.0]).unwrap();
|
||||
assert!((v - 3.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cbrt_8_is_2() {
|
||||
let v = reg().call("cbrt", &[8.0]).unwrap();
|
||||
assert!((v - 2.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cbrt_neg_8_is_neg_2() {
|
||||
let v = reg().call("cbrt", &[-8.0]).unwrap();
|
||||
assert!((v - (-2.0)).abs() < 1e-10);
|
||||
}
|
||||
|
||||
// --- composition ---
|
||||
|
||||
#[test]
|
||||
fn ln_exp_roundtrip() {
|
||||
let r = reg();
|
||||
let inner = r.call("exp", &[5.0]).unwrap();
|
||||
let v = r.call("ln", &[inner]).unwrap();
|
||||
assert!((v - 5.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exp_ln_roundtrip() {
|
||||
let r = reg();
|
||||
let inner = r.call("ln", &[10.0]).unwrap();
|
||||
let v = r.call("exp", &[inner]).unwrap();
|
||||
assert!((v - 10.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sqrt_pow_roundtrip() {
|
||||
let r = reg();
|
||||
let inner = r.call("pow", &[3.0, 2.0]).unwrap();
|
||||
let v = r.call("sqrt", &[inner]).unwrap();
|
||||
assert!((v - 3.0).abs() < 1e-10);
|
||||
}
|
||||
}
|
||||
321
calcpad-engine/src/functions/mod.rs
Normal file
321
calcpad-engine/src/functions/mod.rs
Normal file
@@ -0,0 +1,321 @@
|
||||
//! Function registry and dispatch for CalcPad math functions.
|
||||
//!
|
||||
//! Provides a [`FunctionRegistry`] that maps function names to typed
|
||||
//! implementations across all function categories (trig, logarithmic,
|
||||
//! combinatorics, financial, rounding, list ops, timecodes).
|
||||
|
||||
pub mod combinatorics;
|
||||
pub mod financial;
|
||||
pub mod list_ops;
|
||||
pub mod logarithmic;
|
||||
pub mod rounding;
|
||||
pub mod timecodes;
|
||||
pub mod trig;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Error type returned by function evaluation.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct FunctionError {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl FunctionError {
|
||||
pub fn new(msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
message: msg.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for FunctionError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for FunctionError {}
|
||||
|
||||
/// Angle mode for trigonometric functions.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AngleMode {
|
||||
Radians,
|
||||
Degrees,
|
||||
}
|
||||
|
||||
impl Default for AngleMode {
|
||||
fn default() -> Self {
|
||||
AngleMode::Radians
|
||||
}
|
||||
}
|
||||
|
||||
/// The signature of a function: how many args it accepts and how to call it.
|
||||
#[derive(Clone)]
|
||||
enum FnImpl {
|
||||
/// Fixed-arity function (e.g. sin takes 1 arg, pow takes 2).
|
||||
Fixed {
|
||||
arity: usize,
|
||||
func: fn(&[f64]) -> Result<f64, FunctionError>,
|
||||
},
|
||||
/// Variadic function that accepts 1..N args (e.g. min, max, gcd, lcm).
|
||||
Variadic {
|
||||
min_args: usize,
|
||||
func: fn(&[f64]) -> Result<f64, FunctionError>,
|
||||
},
|
||||
/// Angle-aware trig function (1 arg + angle mode + force-degrees flag).
|
||||
Trig {
|
||||
func: fn(f64, AngleMode, bool) -> Result<f64, FunctionError>,
|
||||
},
|
||||
/// Variable-arity function with a known range (e.g. round takes 1 or 2).
|
||||
RangeArity {
|
||||
min_args: usize,
|
||||
max_args: usize,
|
||||
func: fn(&[f64]) -> Result<f64, FunctionError>,
|
||||
},
|
||||
/// Timecode function that operates on string-like frame values.
|
||||
Timecode {
|
||||
func: fn(&[f64]) -> Result<f64, FunctionError>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Central registry mapping function names to their implementations.
|
||||
pub struct FunctionRegistry {
|
||||
functions: HashMap<String, FnImpl>,
|
||||
}
|
||||
|
||||
impl FunctionRegistry {
|
||||
/// Build a new registry pre-loaded with all built-in functions.
|
||||
pub fn new() -> Self {
|
||||
let mut reg = Self {
|
||||
functions: HashMap::new(),
|
||||
};
|
||||
trig::register(&mut reg);
|
||||
logarithmic::register(&mut reg);
|
||||
combinatorics::register(&mut reg);
|
||||
financial::register(&mut reg);
|
||||
rounding::register(&mut reg);
|
||||
list_ops::register(&mut reg);
|
||||
timecodes::register(&mut reg);
|
||||
reg
|
||||
}
|
||||
|
||||
// ---- registration helpers (called by sub-modules) ----
|
||||
|
||||
pub(crate) fn register_trig(
|
||||
&mut self,
|
||||
name: &str,
|
||||
func: fn(f64, AngleMode, bool) -> Result<f64, FunctionError>,
|
||||
) {
|
||||
self.functions
|
||||
.insert(name.to_string(), FnImpl::Trig { func });
|
||||
}
|
||||
|
||||
pub(crate) fn register_fixed(
|
||||
&mut self,
|
||||
name: &str,
|
||||
arity: usize,
|
||||
func: fn(&[f64]) -> Result<f64, FunctionError>,
|
||||
) {
|
||||
self.functions
|
||||
.insert(name.to_string(), FnImpl::Fixed { arity, func });
|
||||
}
|
||||
|
||||
pub(crate) fn register_variadic(
|
||||
&mut self,
|
||||
name: &str,
|
||||
min_args: usize,
|
||||
func: fn(&[f64]) -> Result<f64, FunctionError>,
|
||||
) {
|
||||
self.functions
|
||||
.insert(name.to_string(), FnImpl::Variadic { min_args, func });
|
||||
}
|
||||
|
||||
pub(crate) fn register_range_arity(
|
||||
&mut self,
|
||||
name: &str,
|
||||
min_args: usize,
|
||||
max_args: usize,
|
||||
func: fn(&[f64]) -> Result<f64, FunctionError>,
|
||||
) {
|
||||
self.functions.insert(
|
||||
name.to_string(),
|
||||
FnImpl::RangeArity {
|
||||
min_args,
|
||||
max_args,
|
||||
func,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn register_timecode(
|
||||
&mut self,
|
||||
name: &str,
|
||||
func: fn(&[f64]) -> Result<f64, FunctionError>,
|
||||
) {
|
||||
self.functions
|
||||
.insert(name.to_string(), FnImpl::Timecode { func });
|
||||
}
|
||||
|
||||
// ---- dispatch ----
|
||||
|
||||
/// Returns true if `name` is a registered function.
|
||||
pub fn has_function(&self, name: &str) -> bool {
|
||||
self.functions.contains_key(name)
|
||||
}
|
||||
|
||||
/// Returns true if `name` is a trig function (needs angle mode).
|
||||
pub fn is_trig(&self, name: &str) -> bool {
|
||||
matches!(self.functions.get(name), Some(FnImpl::Trig { .. }))
|
||||
}
|
||||
|
||||
/// Call a trig function with the given argument, angle mode, and
|
||||
/// force-degrees flag.
|
||||
pub fn call_trig(
|
||||
&self,
|
||||
name: &str,
|
||||
arg: f64,
|
||||
mode: AngleMode,
|
||||
force_degrees: bool,
|
||||
) -> Result<f64, FunctionError> {
|
||||
match self.functions.get(name) {
|
||||
Some(FnImpl::Trig { func }) => func(arg, mode, force_degrees),
|
||||
Some(_) => Err(FunctionError::new(format!(
|
||||
"{} is not a trigonometric function",
|
||||
name
|
||||
))),
|
||||
None => Err(FunctionError::new(format!("Unknown function: {}", name))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Call any non-trig function with a slice of evaluated arguments.
|
||||
pub fn call(&self, name: &str, args: &[f64]) -> Result<f64, FunctionError> {
|
||||
match self.functions.get(name) {
|
||||
Some(FnImpl::Fixed { arity, func }) => {
|
||||
if args.len() != *arity {
|
||||
return Err(FunctionError::new(format!(
|
||||
"{} expects {} argument(s), got {}",
|
||||
name,
|
||||
arity,
|
||||
args.len()
|
||||
)));
|
||||
}
|
||||
func(args)
|
||||
}
|
||||
Some(FnImpl::Variadic { min_args, func }) => {
|
||||
if args.len() < *min_args {
|
||||
return Err(FunctionError::new(format!(
|
||||
"{} requires at least {} argument(s), got {}",
|
||||
name,
|
||||
min_args,
|
||||
args.len()
|
||||
)));
|
||||
}
|
||||
func(args)
|
||||
}
|
||||
Some(FnImpl::RangeArity {
|
||||
min_args,
|
||||
max_args,
|
||||
func,
|
||||
}) => {
|
||||
if args.len() < *min_args || args.len() > *max_args {
|
||||
return Err(FunctionError::new(format!(
|
||||
"{} expects {}-{} argument(s), got {}",
|
||||
name,
|
||||
min_args,
|
||||
max_args,
|
||||
args.len()
|
||||
)));
|
||||
}
|
||||
func(args)
|
||||
}
|
||||
Some(FnImpl::Trig { func }) => {
|
||||
// Convenience: if called via `call()`, default to radians, no force-degrees.
|
||||
if args.len() != 1 {
|
||||
return Err(FunctionError::new(format!(
|
||||
"{} expects 1 argument, got {}",
|
||||
name,
|
||||
args.len()
|
||||
)));
|
||||
}
|
||||
func(args[0], AngleMode::Radians, false)
|
||||
}
|
||||
Some(FnImpl::Timecode { func }) => func(args),
|
||||
None => Err(FunctionError::new(format!("Unknown function: {}", name))),
|
||||
}
|
||||
}
|
||||
|
||||
/// List all registered function names (sorted alphabetically).
|
||||
pub fn function_names(&self) -> Vec<&str> {
|
||||
let mut names: Vec<&str> = self.functions.keys().map(|s| s.as_str()).collect();
|
||||
names.sort();
|
||||
names
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FunctionRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn registry_contains_all_categories() {
|
||||
let reg = FunctionRegistry::new();
|
||||
// Trig
|
||||
assert!(reg.has_function("sin"));
|
||||
assert!(reg.has_function("cos"));
|
||||
assert!(reg.has_function("tanh"));
|
||||
// Logarithmic
|
||||
assert!(reg.has_function("ln"));
|
||||
assert!(reg.has_function("log"));
|
||||
assert!(reg.has_function("sqrt"));
|
||||
// Combinatorics
|
||||
assert!(reg.has_function("factorial"));
|
||||
assert!(reg.has_function("nPr"));
|
||||
assert!(reg.has_function("nCr"));
|
||||
// Financial
|
||||
assert!(reg.has_function("compound_interest"));
|
||||
assert!(reg.has_function("mortgage_payment"));
|
||||
// Rounding
|
||||
assert!(reg.has_function("round"));
|
||||
assert!(reg.has_function("floor"));
|
||||
assert!(reg.has_function("ceil"));
|
||||
// List
|
||||
assert!(reg.has_function("min"));
|
||||
assert!(reg.has_function("max"));
|
||||
assert!(reg.has_function("gcd"));
|
||||
assert!(reg.has_function("lcm"));
|
||||
// Timecodes
|
||||
assert!(reg.has_function("tc_to_frames"));
|
||||
assert!(reg.has_function("frames_to_tc"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trig_dispatch_works() {
|
||||
let reg = FunctionRegistry::new();
|
||||
assert!(reg.is_trig("sin"));
|
||||
let val = reg
|
||||
.call_trig("sin", 0.0, AngleMode::Radians, false)
|
||||
.unwrap();
|
||||
assert!((val - 0.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_unknown_function_returns_error() {
|
||||
let reg = FunctionRegistry::new();
|
||||
let err = reg.call("nonexistent_fn", &[1.0]).unwrap_err();
|
||||
assert!(err.message.contains("Unknown function"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn arity_mismatch_returns_error() {
|
||||
let reg = FunctionRegistry::new();
|
||||
let err = reg.call("sqrt", &[1.0, 2.0]).unwrap_err();
|
||||
assert!(err.message.contains("expects 1 argument"));
|
||||
}
|
||||
}
|
||||
191
calcpad-engine/src/functions/rounding.rs
Normal file
191
calcpad-engine/src/functions/rounding.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
//! Rounding functions: round, floor, ceil, round_to.
|
||||
//!
|
||||
//! - `round(x)` — round to the nearest integer (half rounds away from 0)
|
||||
//! - `round(x, n)` — round to n decimal places
|
||||
//! - `floor(x)` — round toward negative infinity
|
||||
//! - `ceil(x)` — round toward positive infinity
|
||||
//! - `round_to(x, step)` — round x to the nearest multiple of step
|
||||
|
||||
use super::{FunctionError, FunctionRegistry};
|
||||
|
||||
fn floor_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
Ok(args[0].floor())
|
||||
}
|
||||
|
||||
fn ceil_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
Ok(args[0].ceil())
|
||||
}
|
||||
|
||||
fn round_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
let value = args[0];
|
||||
if args.len() == 1 {
|
||||
return Ok(value.round());
|
||||
}
|
||||
let decimals = args[1];
|
||||
if decimals.fract() != 0.0 || decimals < 0.0 {
|
||||
return Err(FunctionError::new(
|
||||
"round decimal places must be a non-negative integer",
|
||||
));
|
||||
}
|
||||
let factor = 10f64.powi(decimals as i32);
|
||||
Ok((value * factor).round() / factor)
|
||||
}
|
||||
|
||||
/// Round x to the nearest multiple of step.
|
||||
/// `round_to(17, 5)` => 15, `round_to(18, 5)` => 20.
|
||||
fn round_to_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
let x = args[0];
|
||||
let step = args[1];
|
||||
if step == 0.0 {
|
||||
return Err(FunctionError::new(
|
||||
"round_to step must be non-zero",
|
||||
));
|
||||
}
|
||||
Ok((x / step).round() * step)
|
||||
}
|
||||
|
||||
/// Register rounding functions.
|
||||
pub fn register(reg: &mut FunctionRegistry) {
|
||||
reg.register_fixed("floor", 1, floor_fn);
|
||||
reg.register_fixed("ceil", 1, ceil_fn);
|
||||
reg.register_range_arity("round", 1, 2, round_fn);
|
||||
reg.register_fixed("round_to", 2, round_to_fn);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn reg() -> FunctionRegistry {
|
||||
FunctionRegistry::new()
|
||||
}
|
||||
|
||||
// --- floor ---
|
||||
|
||||
#[test]
|
||||
fn floor_positive_fraction() {
|
||||
let v = reg().call("floor", &[3.7]).unwrap();
|
||||
assert!((v - 3.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn floor_negative_fraction() {
|
||||
let v = reg().call("floor", &[-3.2]).unwrap();
|
||||
assert!((v - (-4.0)).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn floor_integer_unchanged() {
|
||||
let v = reg().call("floor", &[5.0]).unwrap();
|
||||
assert!((v - 5.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn floor_zero() {
|
||||
let v = reg().call("floor", &[0.0]).unwrap();
|
||||
assert!(v.abs() < 1e-10);
|
||||
}
|
||||
|
||||
// --- ceil ---
|
||||
|
||||
#[test]
|
||||
fn ceil_positive_fraction() {
|
||||
let v = reg().call("ceil", &[3.2]).unwrap();
|
||||
assert!((v - 4.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ceil_negative_fraction() {
|
||||
let v = reg().call("ceil", &[-3.7]).unwrap();
|
||||
assert!((v - (-3.0)).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ceil_integer_unchanged() {
|
||||
let v = reg().call("ceil", &[5.0]).unwrap();
|
||||
assert!((v - 5.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ceil_zero() {
|
||||
let v = reg().call("ceil", &[0.0]).unwrap();
|
||||
assert!(v.abs() < 1e-10);
|
||||
}
|
||||
|
||||
// --- round ---
|
||||
|
||||
#[test]
|
||||
fn round_half_up() {
|
||||
let v = reg().call("round", &[2.5]).unwrap();
|
||||
assert!((v - 3.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_1_5() {
|
||||
let v = reg().call("round", &[1.5]).unwrap();
|
||||
assert!((v - 2.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_negative() {
|
||||
let v = reg().call("round", &[-1.5]).unwrap();
|
||||
// Rust's f64::round rounds half away from zero, so -1.5 => -2.0
|
||||
assert!((v - (-2.0)).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_with_decimal_places() {
|
||||
let v = reg().call("round", &[3.456, 2.0]).unwrap();
|
||||
assert!((v - 3.46).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_with_zero_places() {
|
||||
let v = reg().call("round", &[3.456, 0.0]).unwrap();
|
||||
assert!((v - 3.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_with_one_place() {
|
||||
let v = reg().call("round", &[1.234, 1.0]).unwrap();
|
||||
assert!((v - 1.2).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_negative_decimal_places_error() {
|
||||
let err = reg().call("round", &[3.456, -1.0]).unwrap_err();
|
||||
assert!(err.message.contains("non-negative"));
|
||||
}
|
||||
|
||||
// --- round_to (nearest N) ---
|
||||
|
||||
#[test]
|
||||
fn round_to_nearest_5() {
|
||||
let v = reg().call("round_to", &[17.0, 5.0]).unwrap();
|
||||
assert!((v - 15.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_to_nearest_5_up() {
|
||||
let v = reg().call("round_to", &[18.0, 5.0]).unwrap();
|
||||
assert!((v - 20.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_to_nearest_10() {
|
||||
let v = reg().call("round_to", &[84.0, 10.0]).unwrap();
|
||||
assert!((v - 80.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_to_nearest_0_25() {
|
||||
let v = reg().call("round_to", &[3.3, 0.25]).unwrap();
|
||||
assert!((v - 3.25).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_to_zero_step_error() {
|
||||
let err = reg().call("round_to", &[10.0, 0.0]).unwrap_err();
|
||||
assert!(err.message.contains("non-zero"));
|
||||
}
|
||||
}
|
||||
366
calcpad-engine/src/functions/timecodes.rs
Normal file
366
calcpad-engine/src/functions/timecodes.rs
Normal file
@@ -0,0 +1,366 @@
|
||||
//! Video timecode arithmetic.
|
||||
//!
|
||||
//! Timecodes represent positions in video as `HH:MM:SS:FF` where FF is a
|
||||
//! frame count within the current second. The number of frames per second
|
||||
//! (fps) determines the range of FF (0..fps-1).
|
||||
//!
|
||||
//! ## Functions
|
||||
//!
|
||||
//! - `tc_to_frames(hours, minutes, seconds, frames, fps)` — convert a
|
||||
//! timecode to a total frame count.
|
||||
//! - `frames_to_tc(total_frames, fps)` — convert total frames back to a
|
||||
//! packed timecode value `HH * 1_000_000 + MM * 10_000 + SS * 100 + FF`
|
||||
//! for easy extraction of components.
|
||||
//! - `tc_add_frames(hours, minutes, seconds, frames, fps, add_frames)` —
|
||||
//! add (or subtract) a number of frames to a timecode and return the new
|
||||
//! total frame count.
|
||||
//!
|
||||
//! Common fps values: 24, 25, 29.97 (NTSC drop-frame), 30, 48, 60.
|
||||
//!
|
||||
//! For now we work in non-drop-frame (NDF) mode. Drop-frame support can be
|
||||
//! added later.
|
||||
|
||||
use super::{FunctionError, FunctionRegistry};
|
||||
|
||||
/// Convert a timecode (H, M, S, F, fps) to total frame count.
|
||||
fn tc_to_frames_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
if args.len() != 5 {
|
||||
return Err(FunctionError::new(
|
||||
"tc_to_frames expects 5 arguments: hours, minutes, seconds, frames, fps",
|
||||
));
|
||||
}
|
||||
let hours = args[0];
|
||||
let minutes = args[1];
|
||||
let seconds = args[2];
|
||||
let frames = args[3];
|
||||
let fps = args[4];
|
||||
|
||||
validate_timecode_components(hours, minutes, seconds, frames, fps)?;
|
||||
|
||||
let fps_i = fps as u64;
|
||||
let total = (hours as u64) * 3600 * fps_i
|
||||
+ (minutes as u64) * 60 * fps_i
|
||||
+ (seconds as u64) * fps_i
|
||||
+ (frames as u64);
|
||||
Ok(total as f64)
|
||||
}
|
||||
|
||||
/// Convert total frames to a packed timecode: HH*1_000_000 + MM*10_000 + SS*100 + FF.
|
||||
/// Returns the packed value. Also returns components via the packed encoding.
|
||||
fn frames_to_tc_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
if args.len() != 2 {
|
||||
return Err(FunctionError::new(
|
||||
"frames_to_tc expects 2 arguments: total_frames, fps",
|
||||
));
|
||||
}
|
||||
let total = args[0];
|
||||
let fps = args[1];
|
||||
|
||||
if total < 0.0 || total.fract() != 0.0 {
|
||||
return Err(FunctionError::new(
|
||||
"total_frames must be a non-negative integer",
|
||||
));
|
||||
}
|
||||
if fps <= 0.0 || fps.fract() != 0.0 {
|
||||
return Err(FunctionError::new(
|
||||
"fps must be a positive integer",
|
||||
));
|
||||
}
|
||||
|
||||
let total = total as u64;
|
||||
let fps_i = fps as u64;
|
||||
|
||||
let ff = total % fps_i;
|
||||
let rem = total / fps_i;
|
||||
let ss = rem % 60;
|
||||
let rem = rem / 60;
|
||||
let mm = rem % 60;
|
||||
let hh = rem / 60;
|
||||
|
||||
// Pack into a single number: HH_MM_SS_FF
|
||||
let packed = hh * 1_000_000 + mm * 10_000 + ss * 100 + ff;
|
||||
Ok(packed as f64)
|
||||
}
|
||||
|
||||
/// Add frames to a timecode and return new total frame count.
|
||||
fn tc_add_frames_fn(args: &[f64]) -> Result<f64, FunctionError> {
|
||||
if args.len() != 6 {
|
||||
return Err(FunctionError::new(
|
||||
"tc_add_frames expects 6 arguments: hours, minutes, seconds, frames, fps, add_frames",
|
||||
));
|
||||
}
|
||||
let hours = args[0];
|
||||
let minutes = args[1];
|
||||
let seconds = args[2];
|
||||
let frames = args[3];
|
||||
let fps = args[4];
|
||||
let add_frames = args[5];
|
||||
|
||||
validate_timecode_components(hours, minutes, seconds, frames, fps)?;
|
||||
|
||||
let fps_i = fps as u64;
|
||||
let total = (hours as u64) * 3600 * fps_i
|
||||
+ (minutes as u64) * 60 * fps_i
|
||||
+ (seconds as u64) * fps_i
|
||||
+ (frames as u64);
|
||||
|
||||
let new_total = total as i64 + add_frames as i64;
|
||||
if new_total < 0 {
|
||||
return Err(FunctionError::new(
|
||||
"Resulting timecode would be negative",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(new_total as f64)
|
||||
}
|
||||
|
||||
fn validate_timecode_components(
|
||||
hours: f64,
|
||||
minutes: f64,
|
||||
seconds: f64,
|
||||
frames: f64,
|
||||
fps: f64,
|
||||
) -> Result<(), FunctionError> {
|
||||
if fps <= 0.0 || fps.fract() != 0.0 {
|
||||
return Err(FunctionError::new("fps must be a positive integer"));
|
||||
}
|
||||
if hours < 0.0 || hours.fract() != 0.0 {
|
||||
return Err(FunctionError::new(
|
||||
"hours must be a non-negative integer",
|
||||
));
|
||||
}
|
||||
if minutes < 0.0 || minutes >= 60.0 || minutes.fract() != 0.0 {
|
||||
return Err(FunctionError::new(
|
||||
"minutes must be an integer in 0..59",
|
||||
));
|
||||
}
|
||||
if seconds < 0.0 || seconds >= 60.0 || seconds.fract() != 0.0 {
|
||||
return Err(FunctionError::new(
|
||||
"seconds must be an integer in 0..59",
|
||||
));
|
||||
}
|
||||
if frames < 0.0 || frames >= fps || frames.fract() != 0.0 {
|
||||
return Err(FunctionError::new(format!(
|
||||
"frames must be an integer in 0..{} (fps={})",
|
||||
fps as u64 - 1,
|
||||
fps as u64,
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Register timecode functions.
|
||||
pub fn register(reg: &mut FunctionRegistry) {
|
||||
reg.register_fixed("tc_to_frames", 5, tc_to_frames_fn);
|
||||
reg.register_fixed("frames_to_tc", 2, frames_to_tc_fn);
|
||||
reg.register_fixed("tc_add_frames", 6, tc_add_frames_fn);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn reg() -> FunctionRegistry {
|
||||
FunctionRegistry::new()
|
||||
}
|
||||
|
||||
// --- tc_to_frames ---
|
||||
|
||||
#[test]
|
||||
fn tc_to_frames_zero() {
|
||||
// 00:00:00:00 at 24fps => 0 frames
|
||||
let v = reg()
|
||||
.call("tc_to_frames", &[0.0, 0.0, 0.0, 0.0, 24.0])
|
||||
.unwrap();
|
||||
assert!((v - 0.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tc_to_frames_one_second_24fps() {
|
||||
// 00:00:01:00 at 24fps => 24 frames
|
||||
let v = reg()
|
||||
.call("tc_to_frames", &[0.0, 0.0, 1.0, 0.0, 24.0])
|
||||
.unwrap();
|
||||
assert!((v - 24.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tc_to_frames_one_minute_24fps() {
|
||||
// 00:01:00:00 at 24fps => 1440 frames
|
||||
let v = reg()
|
||||
.call("tc_to_frames", &[0.0, 1.0, 0.0, 0.0, 24.0])
|
||||
.unwrap();
|
||||
assert!((v - 1440.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tc_to_frames_one_hour_24fps() {
|
||||
// 01:00:00:00 at 24fps => 86400 frames
|
||||
let v = reg()
|
||||
.call("tc_to_frames", &[1.0, 0.0, 0.0, 0.0, 24.0])
|
||||
.unwrap();
|
||||
assert!((v - 86400.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tc_to_frames_mixed() {
|
||||
// 01:02:03:04 at 24fps => 1*3600*24 + 2*60*24 + 3*24 + 4 = 86400 + 2880 + 72 + 4 = 89356
|
||||
let v = reg()
|
||||
.call("tc_to_frames", &[1.0, 2.0, 3.0, 4.0, 24.0])
|
||||
.unwrap();
|
||||
assert!((v - 89356.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tc_to_frames_25fps() {
|
||||
// 00:00:01:00 at 25fps => 25 frames
|
||||
let v = reg()
|
||||
.call("tc_to_frames", &[0.0, 0.0, 1.0, 0.0, 25.0])
|
||||
.unwrap();
|
||||
assert!((v - 25.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tc_to_frames_30fps() {
|
||||
// 00:00:01:00 at 30fps => 30 frames
|
||||
let v = reg()
|
||||
.call("tc_to_frames", &[0.0, 0.0, 1.0, 0.0, 30.0])
|
||||
.unwrap();
|
||||
assert!((v - 30.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tc_to_frames_60fps() {
|
||||
// 00:01:00:00 at 60fps => 3600 frames
|
||||
let v = reg()
|
||||
.call("tc_to_frames", &[0.0, 1.0, 0.0, 0.0, 60.0])
|
||||
.unwrap();
|
||||
assert!((v - 3600.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tc_to_frames_invalid_minutes() {
|
||||
let err = reg()
|
||||
.call("tc_to_frames", &[0.0, 60.0, 0.0, 0.0, 24.0])
|
||||
.unwrap_err();
|
||||
assert!(err.message.contains("minutes"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tc_to_frames_invalid_frames() {
|
||||
// Frames >= fps is invalid
|
||||
let err = reg()
|
||||
.call("tc_to_frames", &[0.0, 0.0, 0.0, 24.0, 24.0])
|
||||
.unwrap_err();
|
||||
assert!(err.message.contains("frames"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tc_to_frames_invalid_fps() {
|
||||
let err = reg()
|
||||
.call("tc_to_frames", &[0.0, 0.0, 0.0, 0.0, 0.0])
|
||||
.unwrap_err();
|
||||
assert!(err.message.contains("fps"));
|
||||
}
|
||||
|
||||
// --- frames_to_tc ---
|
||||
|
||||
#[test]
|
||||
fn frames_to_tc_zero() {
|
||||
let v = reg().call("frames_to_tc", &[0.0, 24.0]).unwrap();
|
||||
assert!((v - 0.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frames_to_tc_one_second() {
|
||||
// 24 frames at 24fps => 00:00:01:00 => packed 100
|
||||
let v = reg().call("frames_to_tc", &[24.0, 24.0]).unwrap();
|
||||
assert!((v - 100.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frames_to_tc_one_minute() {
|
||||
// 1440 frames at 24fps => 00:01:00:00 => packed 10000
|
||||
let v = reg().call("frames_to_tc", &[1440.0, 24.0]).unwrap();
|
||||
assert!((v - 10000.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frames_to_tc_one_hour() {
|
||||
// 86400 frames at 24fps => 01:00:00:00 => packed 1000000
|
||||
let v = reg().call("frames_to_tc", &[86400.0, 24.0]).unwrap();
|
||||
assert!((v - 1_000_000.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frames_to_tc_mixed() {
|
||||
// 89356 frames at 24fps => 01:02:03:04 => packed 1020304
|
||||
let v = reg().call("frames_to_tc", &[89356.0, 24.0]).unwrap();
|
||||
assert!((v - 1_020_304.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frames_to_tc_roundtrip() {
|
||||
let r = reg();
|
||||
// Convert to frames then back
|
||||
let frames = r
|
||||
.call("tc_to_frames", &[2.0, 30.0, 15.0, 12.0, 30.0])
|
||||
.unwrap();
|
||||
let packed = r.call("frames_to_tc", &[frames, 30.0]).unwrap();
|
||||
// 02:30:15:12 => packed 2301512
|
||||
assert!((packed - 2_301_512.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frames_to_tc_negative_error() {
|
||||
let err = reg().call("frames_to_tc", &[-1.0, 24.0]).unwrap_err();
|
||||
assert!(err.message.contains("non-negative"));
|
||||
}
|
||||
|
||||
// --- tc_add_frames ---
|
||||
|
||||
#[test]
|
||||
fn tc_add_frames_simple() {
|
||||
let r = reg();
|
||||
// 00:00:00:00 at 24fps + 48 frames => 48
|
||||
let v = r
|
||||
.call("tc_add_frames", &[0.0, 0.0, 0.0, 0.0, 24.0, 48.0])
|
||||
.unwrap();
|
||||
assert!((v - 48.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tc_add_frames_subtract() {
|
||||
let r = reg();
|
||||
// 00:00:02:00 at 24fps = 48 frames, subtract 24 => 24
|
||||
let v = r
|
||||
.call("tc_add_frames", &[0.0, 0.0, 2.0, 0.0, 24.0, -24.0])
|
||||
.unwrap();
|
||||
assert!((v - 24.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tc_add_frames_negative_result_error() {
|
||||
let r = reg();
|
||||
let err = r
|
||||
.call("tc_add_frames", &[0.0, 0.0, 0.0, 0.0, 24.0, -1.0])
|
||||
.unwrap_err();
|
||||
assert!(err.message.contains("negative"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tc_add_frames_cross_minute_boundary() {
|
||||
let r = reg();
|
||||
// 00:00:59:23 at 24fps + 1 frame
|
||||
let base = r
|
||||
.call("tc_to_frames", &[0.0, 0.0, 59.0, 23.0, 24.0])
|
||||
.unwrap();
|
||||
let v = r
|
||||
.call("tc_add_frames", &[0.0, 0.0, 59.0, 23.0, 24.0, 1.0])
|
||||
.unwrap();
|
||||
assert!((v - (base + 1.0)).abs() < 1e-10);
|
||||
// Verify the result converts to 00:01:00:00
|
||||
let packed = r.call("frames_to_tc", &[v, 24.0]).unwrap();
|
||||
assert!((packed - 10000.0).abs() < 1e-10);
|
||||
}
|
||||
}
|
||||
255
calcpad-engine/src/functions/trig.rs
Normal file
255
calcpad-engine/src/functions/trig.rs
Normal file
@@ -0,0 +1,255 @@
|
||||
//! Trigonometric functions: sin, cos, tan, asin, acos, atan, sinh, cosh, tanh.
|
||||
//!
|
||||
//! All trig functions are angle-mode aware. When `AngleMode::Degrees` is active
|
||||
//! (or when `force_degrees` is true), inputs to forward trig functions are
|
||||
//! converted from degrees to radians, and outputs of inverse trig functions
|
||||
//! are converted from radians to degrees. Hyperbolic functions ignore angle mode.
|
||||
|
||||
use super::{AngleMode, FunctionError, FunctionRegistry};
|
||||
|
||||
const DEG_TO_RAD: f64 = std::f64::consts::PI / 180.0;
|
||||
const RAD_TO_DEG: f64 = 180.0 / std::f64::consts::PI;
|
||||
|
||||
fn to_radians(value: f64, mode: AngleMode, force_degrees: bool) -> f64 {
|
||||
if force_degrees || mode == AngleMode::Degrees {
|
||||
value * DEG_TO_RAD
|
||||
} else {
|
||||
value
|
||||
}
|
||||
}
|
||||
|
||||
fn from_radians(value: f64, mode: AngleMode) -> f64 {
|
||||
if mode == AngleMode::Degrees {
|
||||
value * RAD_TO_DEG
|
||||
} else {
|
||||
value
|
||||
}
|
||||
}
|
||||
|
||||
// --- forward trig ---
|
||||
|
||||
fn sin_fn(arg: f64, mode: AngleMode, force_deg: bool) -> Result<f64, FunctionError> {
|
||||
Ok(to_radians(arg, mode, force_deg).sin())
|
||||
}
|
||||
|
||||
fn cos_fn(arg: f64, mode: AngleMode, force_deg: bool) -> Result<f64, FunctionError> {
|
||||
Ok(to_radians(arg, mode, force_deg).cos())
|
||||
}
|
||||
|
||||
fn tan_fn(arg: f64, mode: AngleMode, force_deg: bool) -> Result<f64, FunctionError> {
|
||||
Ok(to_radians(arg, mode, force_deg).tan())
|
||||
}
|
||||
|
||||
// --- inverse trig ---
|
||||
|
||||
fn asin_fn(arg: f64, mode: AngleMode, _force_deg: bool) -> Result<f64, FunctionError> {
|
||||
if arg < -1.0 || arg > 1.0 {
|
||||
return Err(FunctionError::new(
|
||||
"Argument out of domain for asin (must be between -1 and 1)",
|
||||
));
|
||||
}
|
||||
Ok(from_radians(arg.asin(), mode))
|
||||
}
|
||||
|
||||
fn acos_fn(arg: f64, mode: AngleMode, _force_deg: bool) -> Result<f64, FunctionError> {
|
||||
if arg < -1.0 || arg > 1.0 {
|
||||
return Err(FunctionError::new(
|
||||
"Argument out of domain for acos (must be between -1 and 1)",
|
||||
));
|
||||
}
|
||||
Ok(from_radians(arg.acos(), mode))
|
||||
}
|
||||
|
||||
fn atan_fn(arg: f64, mode: AngleMode, _force_deg: bool) -> Result<f64, FunctionError> {
|
||||
Ok(from_radians(arg.atan(), mode))
|
||||
}
|
||||
|
||||
// --- hyperbolic (angle-mode independent) ---
|
||||
|
||||
fn sinh_fn(arg: f64, _mode: AngleMode, _force_deg: bool) -> Result<f64, FunctionError> {
|
||||
Ok(arg.sinh())
|
||||
}
|
||||
|
||||
fn cosh_fn(arg: f64, _mode: AngleMode, _force_deg: bool) -> Result<f64, FunctionError> {
|
||||
Ok(arg.cosh())
|
||||
}
|
||||
|
||||
fn tanh_fn(arg: f64, _mode: AngleMode, _force_deg: bool) -> Result<f64, FunctionError> {
|
||||
Ok(arg.tanh())
|
||||
}
|
||||
|
||||
/// Register all trig functions into the given registry.
|
||||
pub fn register(reg: &mut FunctionRegistry) {
|
||||
reg.register_trig("sin", sin_fn);
|
||||
reg.register_trig("cos", cos_fn);
|
||||
reg.register_trig("tan", tan_fn);
|
||||
reg.register_trig("asin", asin_fn);
|
||||
reg.register_trig("acos", acos_fn);
|
||||
reg.register_trig("atan", atan_fn);
|
||||
reg.register_trig("sinh", sinh_fn);
|
||||
reg.register_trig("cosh", cosh_fn);
|
||||
reg.register_trig("tanh", tanh_fn);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI, SQRT_2};
|
||||
|
||||
fn reg() -> FunctionRegistry {
|
||||
FunctionRegistry::new()
|
||||
}
|
||||
|
||||
// --- radians mode (default) ---
|
||||
|
||||
#[test]
|
||||
fn sin_zero_radians() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("sin", 0.0, AngleMode::Radians, false).unwrap();
|
||||
assert!((v - 0.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sin_pi_radians() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("sin", PI, AngleMode::Radians, false).unwrap();
|
||||
assert!(v.abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cos_zero_is_one() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("cos", 0.0, AngleMode::Radians, false).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tan_zero_is_zero() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("tan", 0.0, AngleMode::Radians, false).unwrap();
|
||||
assert!(v.abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn asin_one_is_pi_over_2() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("asin", 1.0, AngleMode::Radians, false).unwrap();
|
||||
assert!((v - FRAC_PI_2).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acos_one_is_zero() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("acos", 1.0, AngleMode::Radians, false).unwrap();
|
||||
assert!(v.abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn atan_one_is_pi_over_4() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("atan", 1.0, AngleMode::Radians, false).unwrap();
|
||||
assert!((v - FRAC_PI_4).abs() < 1e-10);
|
||||
}
|
||||
|
||||
// --- degrees mode ---
|
||||
|
||||
#[test]
|
||||
fn sin_90_degrees_is_one() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("sin", 90.0, AngleMode::Degrees, false).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cos_zero_degrees_is_one() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("cos", 0.0, AngleMode::Degrees, false).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acos_half_degrees_is_60() {
|
||||
let r = reg();
|
||||
let v = r
|
||||
.call_trig("acos", 0.5, AngleMode::Degrees, false)
|
||||
.unwrap();
|
||||
assert!((v - 60.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn atan_one_degrees_is_45() {
|
||||
let r = reg();
|
||||
let v = r
|
||||
.call_trig("atan", 1.0, AngleMode::Degrees, false)
|
||||
.unwrap();
|
||||
assert!((v - 45.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
// --- force-degrees override (radians mode, but degree symbol present) ---
|
||||
|
||||
#[test]
|
||||
fn sin_45_force_degrees_in_rad_mode() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("sin", 45.0, AngleMode::Radians, true).unwrap();
|
||||
assert!((v - SQRT_2 / 2.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tan_45_force_degrees() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("tan", 45.0, AngleMode::Radians, true).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
// --- hyperbolic ---
|
||||
|
||||
#[test]
|
||||
fn sinh_one() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("sinh", 1.0, AngleMode::Radians, false).unwrap();
|
||||
assert!((v - 1.1752011936438014).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosh_zero_is_one() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("cosh", 0.0, AngleMode::Radians, false).unwrap();
|
||||
assert!((v - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tanh_zero_is_zero() {
|
||||
let r = reg();
|
||||
let v = r.call_trig("tanh", 0.0, AngleMode::Radians, false).unwrap();
|
||||
assert!(v.abs() < 1e-10);
|
||||
}
|
||||
|
||||
// --- domain errors ---
|
||||
|
||||
#[test]
|
||||
fn asin_out_of_domain() {
|
||||
let r = reg();
|
||||
let err = r
|
||||
.call_trig("asin", 2.0, AngleMode::Radians, false)
|
||||
.unwrap_err();
|
||||
assert!(err.message.contains("out of domain"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn asin_negative_out_of_domain() {
|
||||
let r = reg();
|
||||
let err = r
|
||||
.call_trig("asin", -2.0, AngleMode::Radians, false)
|
||||
.unwrap_err();
|
||||
assert!(err.message.contains("out of domain"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acos_out_of_domain() {
|
||||
let r = reg();
|
||||
let err = r
|
||||
.call_trig("acos", 2.0, AngleMode::Radians, false)
|
||||
.unwrap_err();
|
||||
assert!(err.message.contains("out of domain"));
|
||||
}
|
||||
}
|
||||
@@ -131,6 +131,27 @@ fn eval_inner(expr: &Expr, ctx: &mut EvalContext) -> Result<Value, String> {
|
||||
ctx.set_variable(name, result);
|
||||
Ok(val)
|
||||
}
|
||||
|
||||
ExprKind::LineRef(line_num) => {
|
||||
let key = format!("__line_{}", line_num);
|
||||
if let Some(result) = ctx.get_variable(&key) {
|
||||
result_to_value(result)
|
||||
} else {
|
||||
Err(format!("invalid line reference: line {}", line_num))
|
||||
}
|
||||
}
|
||||
|
||||
ExprKind::PrevRef => {
|
||||
if let Some(result) = ctx.get_variable("__prev") {
|
||||
result_to_value(result)
|
||||
} else {
|
||||
Err("no previous line result".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
ExprKind::FunctionCall { name, args } => {
|
||||
eval_function_call(name, args, ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -486,6 +507,67 @@ fn format_duration(value: f64, unit: DurationUnit) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
fn eval_function_call(
|
||||
name: &str,
|
||||
args: &[Expr],
|
||||
ctx: &mut EvalContext,
|
||||
) -> Result<Value, String> {
|
||||
if args.len() != 1 {
|
||||
return Err(format!(
|
||||
"function '{}' expects 1 argument, got {}",
|
||||
name,
|
||||
args.len()
|
||||
));
|
||||
}
|
||||
let arg_val = eval_inner(&args[0], ctx)?;
|
||||
let n = match &arg_val {
|
||||
Value::Number(v) => *v,
|
||||
Value::UnitValue { value, .. } => *value,
|
||||
Value::CurrencyValue { amount, .. } => *amount,
|
||||
_ => return Err(format!("function '{}' requires a numeric argument", name)),
|
||||
};
|
||||
let result = match name.to_lowercase().as_str() {
|
||||
"sqrt" => {
|
||||
if n < 0.0 {
|
||||
return Err("sqrt of negative number".to_string());
|
||||
}
|
||||
n.sqrt()
|
||||
}
|
||||
"abs" => n.abs(),
|
||||
"round" => n.round(),
|
||||
"floor" => n.floor(),
|
||||
"ceil" => n.ceil(),
|
||||
"log" | "log10" => {
|
||||
if n <= 0.0 {
|
||||
return Err("log of non-positive number".to_string());
|
||||
}
|
||||
n.log10()
|
||||
}
|
||||
"ln" => {
|
||||
if n <= 0.0 {
|
||||
return Err("ln of non-positive number".to_string());
|
||||
}
|
||||
n.ln()
|
||||
}
|
||||
"sin" => n.sin(),
|
||||
"cos" => n.cos(),
|
||||
"tan" => n.tan(),
|
||||
_ => return Err(format!("unknown function: {}", name)),
|
||||
};
|
||||
// Preserve unit/currency context through functions
|
||||
match arg_val {
|
||||
Value::UnitValue { unit, .. } => Ok(Value::UnitValue {
|
||||
value: result,
|
||||
unit,
|
||||
}),
|
||||
Value::CurrencyValue { currency, .. } => Ok(Value::CurrencyValue {
|
||||
amount: result,
|
||||
currency,
|
||||
}),
|
||||
_ => Ok(Value::Number(result)),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_units(value: f64, from: &str, to: &str) -> Option<f64> {
|
||||
// Normalize to base unit, then convert to target
|
||||
let (base_value, base_unit) = to_base_unit(value, from)?;
|
||||
|
||||
@@ -80,6 +80,8 @@ impl<'a> Lexer<'a> {
|
||||
| TokenKind::Unit(_)
|
||||
| TokenKind::LParen
|
||||
| TokenKind::RParen
|
||||
| TokenKind::LineRef(_)
|
||||
| TokenKind::PrevRef
|
||||
)
|
||||
});
|
||||
// A single identifier token (potential variable reference) is also calculable
|
||||
@@ -151,6 +153,13 @@ impl<'a> Lexer<'a> {
|
||||
return Some(self.scan_comment());
|
||||
}
|
||||
|
||||
// Hash line reference: #N
|
||||
if b == b'#' {
|
||||
if let Some(tok) = self.try_scan_hash_line_ref() {
|
||||
return Some(tok);
|
||||
}
|
||||
}
|
||||
|
||||
// Currency symbols
|
||||
if let Some(tok) = self.try_scan_currency() {
|
||||
return Some(tok);
|
||||
@@ -358,6 +367,26 @@ impl<'a> Lexer<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn try_scan_hash_line_ref(&mut self) -> Option<Token> {
|
||||
// Check if the character after # is a digit
|
||||
if let Some(next) = self.peek_ahead(1) {
|
||||
if next.is_ascii_digit() {
|
||||
let start = self.pos;
|
||||
self.pos += 1; // skip '#'
|
||||
let num_start = self.pos;
|
||||
self.consume_digits();
|
||||
let num_str = &self.input[num_start..self.pos];
|
||||
if let Ok(line_num) = num_str.parse::<usize>() {
|
||||
return Some(Token::new(
|
||||
TokenKind::LineRef(line_num),
|
||||
Span::new(start, self.pos),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn scan_word(&mut self) -> Token {
|
||||
// "divided by" two-word operator
|
||||
if self.matches_word("divided") {
|
||||
@@ -392,6 +421,49 @@ impl<'a> Lexer<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
// prev / ans keywords (previous line reference)
|
||||
if self.matches_word("prev") {
|
||||
let start = self.pos;
|
||||
self.pos += 4;
|
||||
return Token::new(TokenKind::PrevRef, Span::new(start, self.pos));
|
||||
}
|
||||
if self.matches_word("ans") {
|
||||
let start = self.pos;
|
||||
self.pos += 3;
|
||||
return Token::new(TokenKind::PrevRef, Span::new(start, self.pos));
|
||||
}
|
||||
|
||||
// lineN syntax for line references (e.g., line1, line42)
|
||||
// Can't use matches_word since "line1" has a digit after "line"
|
||||
{
|
||||
let remaining = &self.input[self.pos..];
|
||||
if remaining.len() >= 5
|
||||
&& remaining[..4].eq_ignore_ascii_case("line")
|
||||
&& remaining.as_bytes()[4].is_ascii_digit()
|
||||
{
|
||||
let start = self.pos;
|
||||
self.pos += 4; // skip "line"
|
||||
let num_start = self.pos;
|
||||
self.consume_digits();
|
||||
let num_str = &self.input[num_start..self.pos];
|
||||
if let Ok(line_num) = num_str.parse::<usize>() {
|
||||
return Token::new(
|
||||
TokenKind::LineRef(line_num),
|
||||
Span::new(start, self.pos),
|
||||
);
|
||||
}
|
||||
// Failed to parse — revert to identifier
|
||||
self.pos = start;
|
||||
while self.pos < self.bytes.len()
|
||||
&& (self.bytes[self.pos].is_ascii_alphanumeric() || self.bytes[self.pos] == b'_')
|
||||
{
|
||||
self.pos += 1;
|
||||
}
|
||||
let word = self.input[start..self.pos].to_string();
|
||||
return Token::new(TokenKind::Identifier(word), Span::new(start, self.pos));
|
||||
}
|
||||
}
|
||||
|
||||
// The keyword `in` (for conversions)
|
||||
if self.matches_word("in") {
|
||||
let start = self.pos;
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
pub mod ast;
|
||||
pub mod context;
|
||||
pub mod currency;
|
||||
pub mod datetime;
|
||||
pub mod error;
|
||||
pub mod ffi;
|
||||
pub mod functions;
|
||||
pub mod interpreter;
|
||||
pub mod lexer;
|
||||
pub mod number;
|
||||
@@ -11,6 +14,8 @@ pub mod sheet_context;
|
||||
pub mod span;
|
||||
pub mod token;
|
||||
pub mod types;
|
||||
pub mod units;
|
||||
pub mod variables;
|
||||
|
||||
pub use context::EvalContext;
|
||||
pub use ffi::{FfiResponse, FfiSheetResponse};
|
||||
@@ -19,3 +24,4 @@ pub use pipeline::{eval_line, eval_sheet};
|
||||
pub use sheet_context::SheetContext;
|
||||
pub use span::Span;
|
||||
pub use types::{CalcResult, CalcValue, ResultMetadata, ResultType};
|
||||
pub use variables::{CompletionContext, CompletionItem, CompletionKind, CompletionResult};
|
||||
|
||||
@@ -174,9 +174,47 @@ impl Parser {
|
||||
let name = name.clone();
|
||||
let span = tok.span;
|
||||
self.advance();
|
||||
|
||||
// Check for function call: Identifier followed by '('
|
||||
if !self.at_end() && self.check(&TokenKind::LParen) {
|
||||
self.advance(); // consume '('
|
||||
let mut args = Vec::new();
|
||||
// Parse arguments (comma-separated)
|
||||
if !self.check(&TokenKind::RParen) {
|
||||
let arg = self.parse_expr(Precedence::None)?;
|
||||
args.push(arg);
|
||||
// For future: could add comma support here
|
||||
}
|
||||
if !self.check(&TokenKind::RParen) {
|
||||
return Err(ParseError::new(
|
||||
"expected closing parenthesis ')' after function arguments",
|
||||
self.current_span(),
|
||||
));
|
||||
}
|
||||
let close_span = self.peek().span;
|
||||
self.advance();
|
||||
return Ok(Spanned::new(
|
||||
ExprKind::FunctionCall { name, args },
|
||||
span.merge(close_span),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Spanned::new(ExprKind::Identifier(name), span))
|
||||
}
|
||||
|
||||
TokenKind::LineRef(line_num) => {
|
||||
let line_num = *line_num;
|
||||
let span = tok.span;
|
||||
self.advance();
|
||||
Ok(Spanned::new(ExprKind::LineRef(line_num), span))
|
||||
}
|
||||
|
||||
TokenKind::PrevRef => {
|
||||
let span = tok.span;
|
||||
self.advance();
|
||||
Ok(Spanned::new(ExprKind::PrevRef, span))
|
||||
}
|
||||
|
||||
_ => Err(ParseError::new(
|
||||
format!("unexpected token: {:?}", tok.kind),
|
||||
tok.span,
|
||||
|
||||
@@ -5,6 +5,7 @@ use crate::lexer;
|
||||
use crate::parser;
|
||||
use crate::span::Span;
|
||||
use crate::types::CalcResult;
|
||||
use crate::variables::aggregators::{self, AggregatorKind};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// A parsed line in the sheet.
|
||||
@@ -22,6 +23,10 @@ struct LineEntry {
|
||||
references_vars: Vec<String>,
|
||||
/// Whether this line is a non-calculable text/comment line.
|
||||
is_text: bool,
|
||||
/// Whether this line is a heading (e.g., `## Title`).
|
||||
is_heading: bool,
|
||||
/// The aggregator kind, if this line is a standalone aggregator keyword.
|
||||
aggregator: Option<AggregatorKind>,
|
||||
}
|
||||
|
||||
/// SheetContext holds all evaluation state for a multi-line sheet.
|
||||
@@ -62,6 +67,27 @@ impl SheetContext {
|
||||
|
||||
let trimmed = source.trim();
|
||||
|
||||
// Detect headings and aggregator keywords before tokenizing
|
||||
let is_heading = aggregators::is_heading(trimmed);
|
||||
let aggregator = aggregators::detect_aggregator(trimmed);
|
||||
|
||||
// Headings and aggregators are handled specially, not parsed as expressions
|
||||
if is_heading || aggregator.is_some() {
|
||||
let entry = LineEntry {
|
||||
source: source.to_string(),
|
||||
parsed: None,
|
||||
parse_error: None,
|
||||
defines_var: None,
|
||||
references_vars: Vec::new(),
|
||||
is_text: is_heading,
|
||||
is_heading,
|
||||
aggregator,
|
||||
};
|
||||
self.lines.insert(index, entry);
|
||||
self.dirty_lines.insert(index);
|
||||
return;
|
||||
}
|
||||
|
||||
// Tokenize and parse through the real engine pipeline
|
||||
let tokens = lexer::tokenize(trimmed);
|
||||
|
||||
@@ -97,6 +123,8 @@ impl SheetContext {
|
||||
defines_var,
|
||||
references_vars,
|
||||
is_text,
|
||||
is_heading: false,
|
||||
aggregator: None,
|
||||
};
|
||||
|
||||
self.lines.insert(index, entry);
|
||||
@@ -125,6 +153,8 @@ impl SheetContext {
|
||||
/// This method performs dependency analysis and selective re-evaluation:
|
||||
/// - Lines whose dependencies haven't changed are not recomputed.
|
||||
/// - Circular dependencies are detected and reported as errors.
|
||||
/// - Line results are stored as `__line_N` variables for line references.
|
||||
/// - The `__prev` variable tracks the most recent numeric result for `prev`/`ans`.
|
||||
pub fn eval(&mut self) -> Vec<CalcResult> {
|
||||
let line_indices = self.sorted_line_indices();
|
||||
|
||||
@@ -151,6 +181,11 @@ impl SheetContext {
|
||||
// Build a shared EvalContext and evaluate in order.
|
||||
// We rebuild the context for the full pass so variables propagate correctly.
|
||||
let mut ctx = EvalContext::new();
|
||||
// Track subtotal values for grand total computation
|
||||
let mut subtotal_values: Vec<f64> = Vec::new();
|
||||
// Collect sources and results for aggregator section scanning
|
||||
let mut ordered_results: Vec<CalcResult> = Vec::new();
|
||||
let mut ordered_sources: Vec<String> = Vec::new();
|
||||
|
||||
for &idx in &line_indices {
|
||||
if circular_lines.contains(&idx) {
|
||||
@@ -158,39 +193,120 @@ impl SheetContext {
|
||||
"Circular dependency detected",
|
||||
Span::new(0, 1),
|
||||
);
|
||||
// Store line ref even for errors (as error)
|
||||
ctx.set_variable(&format!("__line_{}", idx + 1), result.clone());
|
||||
ordered_results.push(result.clone());
|
||||
ordered_sources.push(
|
||||
self.lines.get(&idx).map(|e| e.source.clone()).unwrap_or_default(),
|
||||
);
|
||||
self.results.insert(idx, result);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Extract needed fields from entry to avoid borrow conflicts
|
||||
let entry = &self.lines[&idx];
|
||||
let entry_source = entry.source.clone();
|
||||
let entry_is_heading = entry.is_heading;
|
||||
let entry_aggregator = entry.aggregator;
|
||||
let entry_is_text = entry.is_text;
|
||||
let entry_parsed = entry.parsed.clone();
|
||||
let entry_parse_error = entry.parse_error.clone();
|
||||
let entry_defines_var = entry.defines_var.clone();
|
||||
// Drop the borrow of self.lines
|
||||
drop(entry);
|
||||
|
||||
// Heading lines produce no result -- skip them
|
||||
if entry_is_heading {
|
||||
let result = CalcResult::error("no expression found", Span::new(0, entry_source.len()));
|
||||
ordered_results.push(result.clone());
|
||||
ordered_sources.push(entry_source);
|
||||
self.results.insert(idx, result);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Aggregator lines — compute section aggregation
|
||||
if let Some(agg_kind) = entry_aggregator {
|
||||
let span = Span::new(0, entry_source.trim().len());
|
||||
let result = if agg_kind == AggregatorKind::GrandTotal {
|
||||
aggregators::compute_grand_total(&subtotal_values, span)
|
||||
} else {
|
||||
let section_line_idx = ordered_results.len();
|
||||
let values = aggregators::collect_section_values(
|
||||
&ordered_results,
|
||||
&ordered_sources,
|
||||
section_line_idx,
|
||||
);
|
||||
let agg_result = aggregators::compute_aggregation(agg_kind, &values, span);
|
||||
|
||||
// Track subtotal values for grand total
|
||||
if agg_kind == AggregatorKind::Subtotal {
|
||||
if let crate::types::CalcValue::Number { value } = &agg_result.value {
|
||||
subtotal_values.push(*value);
|
||||
}
|
||||
}
|
||||
agg_result
|
||||
};
|
||||
|
||||
self.results.insert(idx, result.clone());
|
||||
// Store as __line_N for line reference support
|
||||
ctx.set_variable(&format!("__line_{}", idx + 1), result.clone());
|
||||
// Update __prev for prev/ans support
|
||||
if result.result_type() != crate::types::ResultType::Error {
|
||||
ctx.set_variable("__prev", result.clone());
|
||||
}
|
||||
ordered_results.push(result);
|
||||
ordered_sources.push(entry_source);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Text/empty lines produce no result -- skip them
|
||||
if entry.is_text || entry.source.trim().is_empty() {
|
||||
let result = CalcResult::error("no expression found", Span::new(0, entry.source.len()));
|
||||
if entry_is_text || entry_source.trim().is_empty() {
|
||||
let result = CalcResult::error("no expression found", Span::new(0, entry_source.len()));
|
||||
ordered_results.push(result.clone());
|
||||
ordered_sources.push(entry_source);
|
||||
self.results.insert(idx, result);
|
||||
continue;
|
||||
}
|
||||
|
||||
if lines_to_eval.contains(&idx) {
|
||||
// Evaluate this line
|
||||
if let Some(ref expr) = entry.parsed {
|
||||
if let Some(ref expr) = entry_parsed {
|
||||
let result = interpreter::evaluate(expr, &mut ctx);
|
||||
self.results.insert(idx, result);
|
||||
self.results.insert(idx, result.clone());
|
||||
// Store as __line_N for line reference support (1-indexed)
|
||||
ctx.set_variable(&format!("__line_{}", idx + 1), result.clone());
|
||||
// Update __prev for prev/ans support (only for non-error results)
|
||||
if result.result_type() != crate::types::ResultType::Error {
|
||||
ctx.set_variable("__prev", result.clone());
|
||||
}
|
||||
ordered_results.push(result);
|
||||
ordered_sources.push(entry_source);
|
||||
} else {
|
||||
// Parse error
|
||||
let err_msg = entry
|
||||
.parse_error
|
||||
let err_msg = entry_parse_error
|
||||
.as_deref()
|
||||
.unwrap_or("Parse error");
|
||||
let result = CalcResult::error(err_msg, Span::new(0, 1));
|
||||
ordered_results.push(result.clone());
|
||||
ordered_sources.push(entry_source);
|
||||
self.results.insert(idx, result);
|
||||
}
|
||||
} else {
|
||||
// Reuse cached result, but still replay variable definitions into ctx
|
||||
if let Some(cached) = self.results.get(&idx) {
|
||||
if let Some(ref var_name) = entry.defines_var {
|
||||
if let Some(cached) = self.results.get(&idx).cloned() {
|
||||
if let Some(ref var_name) = entry_defines_var {
|
||||
ctx.set_variable(var_name, cached.clone());
|
||||
}
|
||||
// Replay line ref and prev for cached results too
|
||||
ctx.set_variable(&format!("__line_{}", idx + 1), cached.clone());
|
||||
if cached.result_type() != crate::types::ResultType::Error {
|
||||
ctx.set_variable("__prev", cached.clone());
|
||||
}
|
||||
ordered_results.push(cached);
|
||||
ordered_sources.push(entry_source);
|
||||
} else {
|
||||
ordered_results.push(CalcResult::error("No result", Span::new(0, 1)));
|
||||
ordered_sources.push(entry_source);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -240,6 +356,7 @@ impl SheetContext {
|
||||
}
|
||||
|
||||
/// Build a map from variable name to the line index that defines it.
|
||||
/// Also maps `__line_N` references to the correct 0-based line index.
|
||||
fn build_var_to_line_map(&self, line_indices: &[usize]) -> HashMap<String, usize> {
|
||||
let mut map = HashMap::new();
|
||||
for &idx in line_indices {
|
||||
@@ -248,7 +365,13 @@ impl SheetContext {
|
||||
map.insert(var_name.clone(), idx);
|
||||
}
|
||||
}
|
||||
// Map __line_N (1-indexed) to line index (0-indexed)
|
||||
map.insert(format!("__line_{}", idx + 1), idx);
|
||||
}
|
||||
// __prev dependencies are handled dynamically during evaluation:
|
||||
// each line referencing __prev depends on the most recent line that produced
|
||||
// a non-error result. Since evaluation is always in line order, this works
|
||||
// without explicit dependency tracking.
|
||||
map
|
||||
}
|
||||
|
||||
@@ -406,6 +529,23 @@ fn collect_references(node: &ExprKind, vars: &mut Vec<String>) {
|
||||
collect_references(&left.node, vars);
|
||||
collect_references(&right.node, vars);
|
||||
}
|
||||
ExprKind::LineRef(line_num) => {
|
||||
let key = format!("__line_{}", line_num);
|
||||
if !vars.contains(&key) {
|
||||
vars.push(key);
|
||||
}
|
||||
}
|
||||
ExprKind::PrevRef => {
|
||||
let key = "__prev".to_string();
|
||||
if !vars.contains(&key) {
|
||||
vars.push(key);
|
||||
}
|
||||
}
|
||||
ExprKind::FunctionCall { args, .. } => {
|
||||
for arg in args {
|
||||
collect_references(&arg.node, vars);
|
||||
}
|
||||
}
|
||||
// Leaf nodes with no variable references
|
||||
ExprKind::Number(_)
|
||||
| ExprKind::UnitNumber { .. }
|
||||
@@ -680,4 +820,259 @@ mod tests {
|
||||
assert!(!dirty.contains(&0), "Line 0 should NOT be dirty");
|
||||
assert!(!dirty.contains(&2), "Line 2 should NOT be dirty (doesn't depend on b)");
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Line References (#N and lineN)
|
||||
// =========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_line_ref_hash_syntax() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "100"); // line 1
|
||||
ctx.set_line(1, "#1 * 2"); // refers to line 1
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[0].value, CalcValue::Number { value: 100.0 });
|
||||
assert_eq!(results[1].value, CalcValue::Number { value: 200.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_line_ref_line_syntax() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "50"); // line 1
|
||||
ctx.set_line(1, "line1 + 10"); // refers to line 1
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[0].value, CalcValue::Number { value: 50.0 });
|
||||
assert_eq!(results[1].value, CalcValue::Number { value: 60.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_line_ref_invalid_out_of_range() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "10");
|
||||
ctx.set_line(1, "#99 + 5");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[0].value, CalcValue::Number { value: 10.0 });
|
||||
assert_eq!(results[1].result_type(), ResultType::Error);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_line_ref_with_variables() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "x = 10"); // line 1, x = 10
|
||||
ctx.set_line(1, "20"); // line 2 = 20
|
||||
ctx.set_line(2, "x + #2"); // x (10) + line2 (20) = 30
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[0].value, CalcValue::Number { value: 10.0 });
|
||||
assert_eq!(results[1].value, CalcValue::Number { value: 20.0 });
|
||||
assert_eq!(results[2].value, CalcValue::Number { value: 30.0 });
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Previous Line Reference (prev / ans)
|
||||
// =========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_prev_basic() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "50");
|
||||
ctx.set_line(1, "prev * 2");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[0].value, CalcValue::Number { value: 50.0 });
|
||||
assert_eq!(results[1].value, CalcValue::Number { value: 100.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ans_basic() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "50");
|
||||
ctx.set_line(1, "ans + 10");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[0].value, CalcValue::Number { value: 50.0 });
|
||||
assert_eq!(results[1].value, CalcValue::Number { value: 60.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prev_on_first_line_is_error() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "prev * 2");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[0].result_type(), ResultType::Error);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prev_chain() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "10");
|
||||
ctx.set_line(1, "prev + 5"); // 15
|
||||
ctx.set_line(2, "prev * 2"); // 30
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[0].value, CalcValue::Number { value: 10.0 });
|
||||
assert_eq!(results[1].value, CalcValue::Number { value: 15.0 });
|
||||
assert_eq!(results[2].value, CalcValue::Number { value: 30.0 });
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Function Calls
|
||||
// =========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_sqrt_function() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "sqrt(16)");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[0].value, CalcValue::Number { value: 4.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_abs_function() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "abs(-5)");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[0].value, CalcValue::Number { value: 5.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_round_function() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "round(3.7)");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[0].value, CalcValue::Number { value: 4.0 });
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Aggregators
|
||||
// =========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_sum_aggregator() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "10");
|
||||
ctx.set_line(1, "20");
|
||||
ctx.set_line(2, "30");
|
||||
ctx.set_line(3, "40");
|
||||
ctx.set_line(4, "sum");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[4].value, CalcValue::Number { value: 100.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_total_aggregator() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "10");
|
||||
ctx.set_line(1, "20");
|
||||
ctx.set_line(2, "30");
|
||||
ctx.set_line(3, "total");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[3].value, CalcValue::Number { value: 60.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_average_aggregator() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "10");
|
||||
ctx.set_line(1, "20");
|
||||
ctx.set_line(2, "30");
|
||||
ctx.set_line(3, "average");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[3].value, CalcValue::Number { value: 20.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_min_aggregator() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "5");
|
||||
ctx.set_line(1, "12");
|
||||
ctx.set_line(2, "3");
|
||||
ctx.set_line(3, "8");
|
||||
ctx.set_line(4, "min");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[4].value, CalcValue::Number { value: 3.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_aggregator() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "5");
|
||||
ctx.set_line(1, "12");
|
||||
ctx.set_line(2, "3");
|
||||
ctx.set_line(3, "8");
|
||||
ctx.set_line(4, "max");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[4].value, CalcValue::Number { value: 12.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_count_aggregator() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "5");
|
||||
ctx.set_line(1, "12");
|
||||
ctx.set_line(2, "3");
|
||||
ctx.set_line(3, "8");
|
||||
ctx.set_line(4, "count");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[4].value, CalcValue::Number { value: 4.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aggregator_with_heading_section() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "10");
|
||||
ctx.set_line(1, "20");
|
||||
ctx.set_line(2, "## Monthly Costs");
|
||||
ctx.set_line(3, "100");
|
||||
ctx.set_line(4, "200");
|
||||
ctx.set_line(5, "sum");
|
||||
let results = ctx.eval();
|
||||
// sum should only include lines 3 and 4 (after heading), not lines 0 and 1
|
||||
assert_eq!(results[5].value, CalcValue::Number { value: 300.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_section_aggregator() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "## Empty Section");
|
||||
ctx.set_line(1, "sum");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[1].value, CalcValue::Number { value: 0.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subtotal_and_grand_total() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "## Section A");
|
||||
ctx.set_line(1, "100");
|
||||
ctx.set_line(2, "200");
|
||||
ctx.set_line(3, "subtotal"); // 300
|
||||
ctx.set_line(4, "## Section B");
|
||||
ctx.set_line(5, "50");
|
||||
ctx.set_line(6, "75");
|
||||
ctx.set_line(7, "subtotal"); // 125
|
||||
ctx.set_line(8, "grand total"); // 300 + 125 = 425
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[3].value, CalcValue::Number { value: 300.0 });
|
||||
assert_eq!(results[7].value, CalcValue::Number { value: 125.0 });
|
||||
assert_eq!(results[8].value, CalcValue::Number { value: 425.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aggregator_skips_comments_and_errors() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "10");
|
||||
ctx.set_line(1, "// This is a comment");
|
||||
ctx.set_line(2, "20");
|
||||
ctx.set_line(3, "sum");
|
||||
let results = ctx.eval();
|
||||
// sum should include 10 and 20, skipping the comment
|
||||
assert_eq!(results[3].value, CalcValue::Number { value: 30.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aggregator_case_insensitive() {
|
||||
let mut ctx = SheetContext::new();
|
||||
ctx.set_line(0, "10");
|
||||
ctx.set_line(1, "20");
|
||||
ctx.set_line(2, "SUM");
|
||||
let results = ctx.eval();
|
||||
assert_eq!(results[2].value, CalcValue::Number { value: 30.0 });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,6 +46,10 @@ pub enum TokenKind {
|
||||
NotEqual,
|
||||
/// Assignment `=`.
|
||||
Assign,
|
||||
/// Line reference: `line1`, `#1` (stores the 1-indexed line number).
|
||||
LineRef(usize),
|
||||
/// Previous-line reference: `prev` or `ans`.
|
||||
PrevRef,
|
||||
/// A generic keyword (discount, off, etc.).
|
||||
Keyword(String),
|
||||
/// A comment token.
|
||||
|
||||
93
calcpad-engine/src/units/categories.rs
Normal file
93
calcpad-engine/src/units/categories.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
//! Unit categories that determine which units can convert to each other.
|
||||
|
||||
/// Unit category -- determines which units can convert to each other.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum UnitCategory {
|
||||
Length,
|
||||
Mass,
|
||||
Volume,
|
||||
Area,
|
||||
Speed,
|
||||
Temperature,
|
||||
Data,
|
||||
Angle,
|
||||
Time,
|
||||
Pressure,
|
||||
Energy,
|
||||
Power,
|
||||
Force,
|
||||
CssScreen,
|
||||
}
|
||||
|
||||
impl UnitCategory {
|
||||
/// Return all standard categories (excludes CssScreen which is special).
|
||||
pub fn all() -> &'static [UnitCategory] {
|
||||
&[
|
||||
UnitCategory::Length,
|
||||
UnitCategory::Mass,
|
||||
UnitCategory::Volume,
|
||||
UnitCategory::Area,
|
||||
UnitCategory::Speed,
|
||||
UnitCategory::Temperature,
|
||||
UnitCategory::Data,
|
||||
UnitCategory::Angle,
|
||||
UnitCategory::Time,
|
||||
UnitCategory::Pressure,
|
||||
UnitCategory::Energy,
|
||||
UnitCategory::Power,
|
||||
UnitCategory::Force,
|
||||
UnitCategory::CssScreen,
|
||||
]
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
UnitCategory::Length => "length",
|
||||
UnitCategory::Mass => "mass",
|
||||
UnitCategory::Volume => "volume",
|
||||
UnitCategory::Area => "area",
|
||||
UnitCategory::Speed => "speed",
|
||||
UnitCategory::Temperature => "temperature",
|
||||
UnitCategory::Data => "data",
|
||||
UnitCategory::Angle => "angle",
|
||||
UnitCategory::Time => "time",
|
||||
UnitCategory::Pressure => "pressure",
|
||||
UnitCategory::Energy => "energy",
|
||||
UnitCategory::Power => "power",
|
||||
UnitCategory::Force => "force",
|
||||
UnitCategory::CssScreen => "css/screen",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for UnitCategory {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.name())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_all_returns_14_categories() {
|
||||
assert_eq!(UnitCategory::all().len(), 14);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_category_names_are_unique() {
|
||||
let names: Vec<&str> = UnitCategory::all().iter().map(|c| c.name()).collect();
|
||||
let mut deduped = names.clone();
|
||||
deduped.sort();
|
||||
deduped.dedup();
|
||||
assert_eq!(names.len(), deduped.len(), "Category names must be unique");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_display_matches_name() {
|
||||
for cat in UnitCategory::all() {
|
||||
assert_eq!(format!("{}", cat), cat.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
168
calcpad-engine/src/units/css.rs
Normal file
168
calcpad-engine/src/units/css.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
//! CSS and screen units with configurable PPI and em base size.
|
||||
//!
|
||||
//! CSS units (px, pt, em, rem, pica) are registered in the global registry with
|
||||
//! default values (PPI=96, em=16px). For accurate conversions at non-standard
|
||||
//! display densities, use `convert_css()` with a custom `CssConfig`.
|
||||
|
||||
use super::categories::UnitCategory;
|
||||
use super::{Conversion, UnitDef, UnitRegistry};
|
||||
|
||||
/// Configuration for CSS/screen unit conversions.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct CssConfig {
|
||||
/// Pixels per inch. Default: 96.
|
||||
pub ppi: f64,
|
||||
/// Base font size in pixels (for em/rem). Default: 16.
|
||||
pub em_base_px: f64,
|
||||
}
|
||||
|
||||
impl Default for CssConfig {
|
||||
fn default() -> Self {
|
||||
CssConfig {
|
||||
ppi: 96.0,
|
||||
em_base_px: 16.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CssConfig {
|
||||
/// Get the conversion factor from a CSS unit to pixels.
|
||||
fn to_px_factor(&self, unit_name: &str) -> Option<f64> {
|
||||
let lower = unit_name.to_lowercase();
|
||||
match lower.as_str() {
|
||||
"px" | "pixel" | "pixels" => Some(1.0),
|
||||
"pt" | "point" | "points" => Some(self.ppi / 72.0),
|
||||
"em" | "ems" => Some(self.em_base_px),
|
||||
"rem" | "rems" => Some(self.em_base_px),
|
||||
"pc" | "pica" | "picas" => Some(12.0 * self.ppi / 72.0),
|
||||
"dppx" => Some(1.0),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert between CSS/screen units with configurable PPI and em base.
|
||||
pub fn convert_css(value: f64, from: &str, to: &str, config: &CssConfig) -> Result<f64, String> {
|
||||
let from_factor = config
|
||||
.to_px_factor(from)
|
||||
.ok_or_else(|| format!("Unknown CSS unit: {}", from))?;
|
||||
let to_factor = config
|
||||
.to_px_factor(to)
|
||||
.ok_or_else(|| format!("Unknown CSS unit: {}", to))?;
|
||||
|
||||
// Convert: from -> px -> to
|
||||
let px_value = value * from_factor;
|
||||
Ok(px_value / to_factor)
|
||||
}
|
||||
|
||||
/// Register CSS/screen units in the registry.
|
||||
/// These are registered with DEFAULT factors (PPI=96, em=16px).
|
||||
/// Actual conversions should use `convert_css()` for runtime config support.
|
||||
pub(crate) fn register_css_screen(reg: &mut UnitRegistry) {
|
||||
let c = UnitCategory::CssScreen;
|
||||
|
||||
fn linear(reg: &mut UnitRegistry, name: &'static str, abbrev: &'static str, factor: f64, aliases: &[&str]) {
|
||||
reg.register(
|
||||
UnitDef {
|
||||
name,
|
||||
abbreviation: abbrev,
|
||||
category: UnitCategory::CssScreen,
|
||||
conversion: Conversion::Linear(factor),
|
||||
},
|
||||
aliases,
|
||||
);
|
||||
}
|
||||
|
||||
// px is the base unit (factor = 1.0)
|
||||
linear(reg, "pixel", "px", 1.0, &["pixels"]);
|
||||
// pt: 1pt = PPI/72 px. At default PPI=96: 96/72 = 4/3
|
||||
linear(reg, "point", "pt", 96.0 / 72.0, &["points"]);
|
||||
// em: 1em = em_base px. Default em_base=16
|
||||
linear(reg, "em", "em", 16.0, &["ems"]);
|
||||
// rem: 1rem = em_base px. Default em_base=16
|
||||
linear(reg, "rem", "rem", 16.0, &["rems"]);
|
||||
// pc (pica): 1pc = 12pt = 12 * PPI/72 px. At default PPI=96: 16px
|
||||
linear(reg, "pica", "pica", 12.0 * 96.0 / 72.0, &["picas"]);
|
||||
// dppx (dots per pixel): device pixel ratio unit, base = 1.0
|
||||
linear(reg, "dppx", "dppx", 1.0, &[]);
|
||||
|
||||
// Suppress unused variable warning
|
||||
let _ = c;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = CssConfig::default();
|
||||
assert_eq!(config.ppi, 96.0);
|
||||
assert_eq!(config.em_base_px, 16.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_12pt_to_px_default() {
|
||||
let config = CssConfig::default();
|
||||
let result = convert_css(12.0, "pt", "px", &config).unwrap();
|
||||
assert!((result - 16.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_2em_to_px_default() {
|
||||
let config = CssConfig::default();
|
||||
let result = convert_css(2.0, "em", "px", &config).unwrap();
|
||||
assert!((result - 32.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_1rem_to_px_default() {
|
||||
let config = CssConfig::default();
|
||||
let result = convert_css(1.0, "rem", "px", &config).unwrap();
|
||||
assert!((result - 16.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_96px_to_pt_default() {
|
||||
let config = CssConfig::default();
|
||||
let result = convert_css(96.0, "px", "pt", &config).unwrap();
|
||||
assert!((result - 72.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_12pt_to_px_retina() {
|
||||
let config = CssConfig { ppi: 326.0, em_base_px: 16.0 };
|
||||
let result = convert_css(12.0, "pt", "px", &config).unwrap();
|
||||
let expected = 12.0 * 326.0 / 72.0;
|
||||
assert!((result - expected).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_2em_custom_base() {
|
||||
let config = CssConfig { ppi: 96.0, em_base_px: 20.0 };
|
||||
let result = convert_css(2.0, "em", "px", &config).unwrap();
|
||||
assert!((result - 40.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_css_unit() {
|
||||
let config = CssConfig::default();
|
||||
let result = convert_css(1.0, "frobbles", "px", &config);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_px_identity() {
|
||||
let config = CssConfig::default();
|
||||
let result = convert_css(42.0, "px", "px", &config).unwrap();
|
||||
assert!((result - 42.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pica_to_px() {
|
||||
let config = CssConfig::default();
|
||||
let result = convert_css(1.0, "pica", "px", &config).unwrap();
|
||||
// 1 pica = 12 pt = 12 * 96/72 = 16 px at default PPI
|
||||
assert!((result - 16.0).abs() < 1e-10);
|
||||
}
|
||||
}
|
||||
406
calcpad-engine/src/units/custom.rs
Normal file
406
calcpad-engine/src/units/custom.rs
Normal file
@@ -0,0 +1,406 @@
|
||||
//! Custom user-defined units.
|
||||
//!
|
||||
//! Allows users to define their own units in terms of existing units (built-in or
|
||||
//! previously defined custom units). Supports chaining, circular dependency detection,
|
||||
//! auto-generated plural aliases, and warnings when shadowing built-in units.
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use super::categories::UnitCategory;
|
||||
|
||||
/// A custom user-defined unit entry.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CustomUnitDef {
|
||||
/// Canonical name (e.g., "sprint").
|
||||
pub name: String,
|
||||
/// Conversion factor: 1 custom_unit = factor * base_unit (in category base).
|
||||
pub factor: f64,
|
||||
/// The base unit's canonical name (e.g., "second" for time units).
|
||||
pub base_unit_name: String,
|
||||
/// The category inherited from the base unit.
|
||||
pub category: UnitCategory,
|
||||
}
|
||||
|
||||
/// Result of registering a custom unit.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RegisterResult {
|
||||
/// Warning message if this unit shadows a built-in.
|
||||
pub warning: Option<String>,
|
||||
}
|
||||
|
||||
/// Registry for custom user-defined units.
|
||||
///
|
||||
/// Supports registration, lookup, alias generation (plural forms),
|
||||
/// circular dependency detection, and built-in shadowing warnings.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CustomUnitRegistry {
|
||||
/// Maps lowercase name/alias -> custom unit definition.
|
||||
units: HashMap<String, CustomUnitDef>,
|
||||
}
|
||||
|
||||
impl CustomUnitRegistry {
|
||||
pub fn new() -> Self {
|
||||
CustomUnitRegistry {
|
||||
units: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a custom unit definition.
|
||||
///
|
||||
/// `name`: the new unit name (e.g., "sprint")
|
||||
/// `factor`: how many base units equal one custom unit (e.g., 2.0 for "1 sprint = 2 weeks")
|
||||
/// `base_unit_name`: the name of the base unit as written (e.g., "weeks")
|
||||
///
|
||||
/// The base unit must resolve to either a built-in unit or a previously-registered
|
||||
/// custom unit. Returns an error for circular dependencies or unresolvable base units.
|
||||
pub fn register(
|
||||
&mut self,
|
||||
name: &str,
|
||||
factor: f64,
|
||||
base_unit_name: &str,
|
||||
) -> Result<RegisterResult, String> {
|
||||
// Resolve the base unit -- could be built-in or another custom unit
|
||||
let (resolved_factor, canonical_base, category) =
|
||||
self.resolve_base_chain(base_unit_name, name)?;
|
||||
|
||||
let total_factor = factor * resolved_factor;
|
||||
|
||||
let def = CustomUnitDef {
|
||||
name: name.to_string(),
|
||||
factor: total_factor,
|
||||
base_unit_name: canonical_base,
|
||||
category,
|
||||
};
|
||||
|
||||
// Check for built-in shadowing
|
||||
let warning = {
|
||||
let reg = super::registry();
|
||||
if reg.lookup(name).is_some() {
|
||||
Some(format!(
|
||||
"Custom unit '{}' shadows built-in unit '{}'",
|
||||
name, name
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// Register the unit under its canonical name
|
||||
self.units.insert(name.to_lowercase(), def.clone());
|
||||
|
||||
// Auto-generate plural alias: add "s" if name doesn't end with "s"
|
||||
let lower_name = name.to_lowercase();
|
||||
if !lower_name.ends_with('s') {
|
||||
let plural = format!("{}s", lower_name);
|
||||
self.units.insert(plural, def);
|
||||
}
|
||||
|
||||
Ok(RegisterResult { warning })
|
||||
}
|
||||
|
||||
/// Look up a custom unit by name or alias.
|
||||
pub fn lookup(&self, name: &str) -> Option<&CustomUnitDef> {
|
||||
self.units.get(&name.to_lowercase())
|
||||
}
|
||||
|
||||
/// Get all registered custom unit names (including aliases).
|
||||
pub fn unit_names(&self) -> HashSet<String> {
|
||||
self.units.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Convert a value from a custom unit to a target unit.
|
||||
///
|
||||
/// First converts to the category's base unit using the custom unit's
|
||||
/// total factor, then converts from base to target using the built-in registry.
|
||||
pub fn convert(
|
||||
&self,
|
||||
value: f64,
|
||||
from: &str,
|
||||
to: &str,
|
||||
) -> Result<f64, String> {
|
||||
let from_custom = self.lookup(from);
|
||||
let to_custom = self.lookup(to);
|
||||
|
||||
match (from_custom, to_custom) {
|
||||
(Some(from_def), Some(to_def)) => {
|
||||
// Both are custom units -- must be same category
|
||||
if from_def.category != to_def.category {
|
||||
return Err(format!(
|
||||
"Cannot convert between '{}' ({}) and '{}' ({})",
|
||||
from, from_def.category, to, to_def.category
|
||||
));
|
||||
}
|
||||
let base_value = self.custom_to_category_base(value, from_def)?;
|
||||
let result = self.category_base_to_custom(base_value, to_def)?;
|
||||
Ok(result)
|
||||
}
|
||||
(Some(from_def), None) => {
|
||||
// Source is custom, target is built-in
|
||||
let base_value = self.custom_to_category_base(value, from_def)?;
|
||||
let reg = super::registry();
|
||||
let to_resolved = reg
|
||||
.resolve_with_prefix(to)
|
||||
.ok_or_else(|| format!("Unknown unit: {}", to))?;
|
||||
if to_resolved.unit.category != from_def.category {
|
||||
return Err(format!(
|
||||
"Cannot convert between '{}' ({}) and '{}' ({})",
|
||||
from, from_def.category, to_resolved.unit.name, to_resolved.unit.category
|
||||
));
|
||||
}
|
||||
Ok(to_resolved.from_base(base_value))
|
||||
}
|
||||
(None, Some(to_def)) => {
|
||||
// Source is built-in, target is custom
|
||||
let reg = super::registry();
|
||||
let from_resolved = reg
|
||||
.resolve_with_prefix(from)
|
||||
.ok_or_else(|| format!("Unknown unit: {}", from))?;
|
||||
if from_resolved.unit.category != to_def.category {
|
||||
return Err(format!(
|
||||
"Cannot convert between '{}' ({}) and '{}' ({})",
|
||||
from_resolved.unit.name, from_resolved.unit.category, to, to_def.category
|
||||
));
|
||||
}
|
||||
let base_value = from_resolved.to_base(value);
|
||||
let result = self.category_base_to_custom(base_value, to_def)?;
|
||||
Ok(result)
|
||||
}
|
||||
(None, None) => {
|
||||
// Neither is custom -- delegate to built-in conversion
|
||||
super::convert(value, from, to)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a value in a custom unit to the category's base unit.
|
||||
fn custom_to_category_base(&self, value: f64, def: &CustomUnitDef) -> Result<f64, String> {
|
||||
let reg = super::registry();
|
||||
let base_unit = reg
|
||||
.resolve_with_prefix(&def.base_unit_name)
|
||||
.ok_or_else(|| format!("Base unit '{}' not found", def.base_unit_name))?;
|
||||
Ok(base_unit.to_base(value * def.factor))
|
||||
}
|
||||
|
||||
/// Convert a value from the category's base unit to a custom unit.
|
||||
fn category_base_to_custom(&self, base_value: f64, def: &CustomUnitDef) -> Result<f64, String> {
|
||||
let reg = super::registry();
|
||||
let base_unit = reg
|
||||
.resolve_with_prefix(&def.base_unit_name)
|
||||
.ok_or_else(|| format!("Base unit '{}' not found", def.base_unit_name))?;
|
||||
let value_in_base_unit = base_unit.from_base(base_value);
|
||||
Ok(value_in_base_unit / def.factor)
|
||||
}
|
||||
|
||||
/// Resolve a base unit chain, detecting circular dependencies.
|
||||
fn resolve_base_chain(
|
||||
&self,
|
||||
base_name: &str,
|
||||
defining_name: &str,
|
||||
) -> Result<(f64, String, UnitCategory), String> {
|
||||
let mut visited: HashSet<String> = HashSet::new();
|
||||
visited.insert(defining_name.to_lowercase());
|
||||
self.resolve_base_chain_inner(base_name, &mut visited)
|
||||
}
|
||||
|
||||
fn resolve_base_chain_inner(
|
||||
&self,
|
||||
base_name: &str,
|
||||
visited: &mut HashSet<String>,
|
||||
) -> Result<(f64, String, UnitCategory), String> {
|
||||
let lower = base_name.to_lowercase();
|
||||
|
||||
// Check for circular dependency
|
||||
if visited.contains(&lower) {
|
||||
return Err(format!(
|
||||
"Circular dependency detected in custom unit definitions involving '{}'",
|
||||
base_name
|
||||
));
|
||||
}
|
||||
|
||||
// Try built-in registry first
|
||||
let reg = super::registry();
|
||||
if let Some(resolved) = reg.resolve_with_prefix(base_name) {
|
||||
return Ok((
|
||||
resolved.prefix_factor,
|
||||
resolved.unit.name.to_string(),
|
||||
resolved.unit.category,
|
||||
));
|
||||
}
|
||||
|
||||
// Try custom unit registry
|
||||
if let Some(custom) = self.units.get(&lower) {
|
||||
visited.insert(lower);
|
||||
let (chain_factor, canonical, category) =
|
||||
self.resolve_base_chain_inner(&custom.base_unit_name, visited)?;
|
||||
Ok((custom.factor * chain_factor, canonical, category))
|
||||
} else {
|
||||
Err(format!("Unknown base unit: '{}'", base_name))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CustomUnitRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_register_custom_unit() {
|
||||
let mut reg = CustomUnitRegistry::new();
|
||||
let result = reg.register("sprint", 2.0, "weeks").unwrap();
|
||||
assert!(result.warning.is_none());
|
||||
|
||||
let def = reg.lookup("sprint").unwrap();
|
||||
assert_eq!(def.name, "sprint");
|
||||
assert_eq!(def.category, UnitCategory::Time);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plural_alias() {
|
||||
let mut reg = CustomUnitRegistry::new();
|
||||
reg.register("sprint", 2.0, "weeks").unwrap();
|
||||
assert!(reg.lookup("sprints").is_some());
|
||||
assert!(reg.lookup("sprint").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_duplicate_plural_for_s_ending() {
|
||||
let mut reg = CustomUnitRegistry::new();
|
||||
reg.register("kudos", 1.0, "hours").unwrap();
|
||||
assert!(reg.lookup("kudos").is_some());
|
||||
assert!(reg.lookup("kudoss").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_custom_to_builtin() {
|
||||
let mut reg = CustomUnitRegistry::new();
|
||||
reg.register("sprint", 2.0, "weeks").unwrap();
|
||||
|
||||
// 3 sprints = 6 weeks = 42 days
|
||||
let result = reg.convert(3.0, "sprints", "days").unwrap();
|
||||
assert!(
|
||||
(result - 42.0).abs() < 1e-6,
|
||||
"Expected 42.0, got {}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_custom_story_points() {
|
||||
let mut reg = CustomUnitRegistry::new();
|
||||
reg.register("story_point", 4.0, "hours").unwrap();
|
||||
|
||||
// 10 story_points = 40 hours
|
||||
let result = reg.convert(10.0, "story_points", "hours").unwrap();
|
||||
assert!(
|
||||
(result - 40.0).abs() < 1e-6,
|
||||
"Expected 40.0, got {}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chained_custom_units() {
|
||||
let mut reg = CustomUnitRegistry::new();
|
||||
reg.register("sprint", 2.0, "weeks").unwrap();
|
||||
reg.register("quarter", 6.0, "sprints").unwrap();
|
||||
|
||||
// 1 quarter = 6 sprints = 12 weeks = 84 days
|
||||
let result = reg.convert(1.0, "quarter", "days").unwrap();
|
||||
assert!(
|
||||
(result - 84.0).abs() < 1e-6,
|
||||
"Expected 84.0, got {}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circular_dependency_self_reference() {
|
||||
let mut reg = CustomUnitRegistry::new();
|
||||
let result = reg.register("foo", 2.0, "foo");
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("Circular dependency"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_base_unit() {
|
||||
let mut reg = CustomUnitRegistry::new();
|
||||
let result = reg.register("foo", 2.0, "nonexistent");
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("Unknown base unit"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builtin_shadowing_warning() {
|
||||
let mut reg = CustomUnitRegistry::new();
|
||||
let result = reg.register("meter", 100.0, "cm").unwrap();
|
||||
assert!(result.warning.is_some());
|
||||
assert!(result.warning.unwrap().contains("shadows"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_case_insensitive_lookup() {
|
||||
let mut reg = CustomUnitRegistry::new();
|
||||
reg.register("Sprint", 2.0, "weeks").unwrap();
|
||||
assert!(reg.lookup("sprint").is_some());
|
||||
assert!(reg.lookup("SPRINT").is_some());
|
||||
assert!(reg.lookup("Sprint").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chaining_is_ok() {
|
||||
let mut reg = CustomUnitRegistry::new();
|
||||
reg.register("foo", 2.0, "hours").unwrap();
|
||||
let result = reg.register("bar", 3.0, "foo");
|
||||
assert!(result.is_ok());
|
||||
let result = reg.register("baz", 1.0, "bar");
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builtin_to_custom_conversion() {
|
||||
let mut reg = CustomUnitRegistry::new();
|
||||
reg.register("sprint", 2.0, "weeks").unwrap();
|
||||
|
||||
// 42 days = 3 sprints
|
||||
let result = reg.convert(42.0, "days", "sprints").unwrap();
|
||||
assert!(
|
||||
(result - 3.0).abs() < 1e-6,
|
||||
"Expected 3.0, got {}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_to_custom_conversion() {
|
||||
let mut reg = CustomUnitRegistry::new();
|
||||
reg.register("sprint", 2.0, "weeks").unwrap();
|
||||
reg.register("milestone", 4.0, "weeks").unwrap();
|
||||
|
||||
// 1 milestone = 4 weeks, 1 sprint = 2 weeks
|
||||
// So 1 milestone = 2 sprints
|
||||
let result = reg.convert(1.0, "milestones", "sprints").unwrap();
|
||||
assert!(
|
||||
(result - 2.0).abs() < 1e-6,
|
||||
"Expected 2.0, got {}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fallback_to_builtin_conversion() {
|
||||
let reg = CustomUnitRegistry::new();
|
||||
// Neither unit is custom -- should delegate to built-in
|
||||
let result = reg.convert(5.0, "km", "miles").unwrap();
|
||||
assert!(
|
||||
(result - 3.10686).abs() < 1e-4,
|
||||
"Expected ~3.10686, got {}",
|
||||
result
|
||||
);
|
||||
}
|
||||
}
|
||||
175
calcpad-engine/src/units/data.rs
Normal file
175
calcpad-engine/src/units/data.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
//! Data units with proper binary vs decimal distinction and case-sensitive aliases.
|
||||
//!
|
||||
//! - Decimal (SI): KB, MB, GB, TB, PB, EB (powers of 1000)
|
||||
//! - Binary (IEC): KiB, MiB, GiB, TiB, PiB (powers of 1024)
|
||||
//! - Case-sensitive: "B" = byte, "b" = bit, "MB" = megabyte, "Mb" = megabit
|
||||
|
||||
use super::categories::UnitCategory;
|
||||
use super::{Conversion, UnitDef, UnitRegistry};
|
||||
|
||||
/// Helper to register a linear data unit.
|
||||
fn linear(
|
||||
reg: &mut UnitRegistry,
|
||||
name: &'static str,
|
||||
abbrev: &'static str,
|
||||
factor: f64,
|
||||
aliases: &[&str],
|
||||
) {
|
||||
reg.register(
|
||||
UnitDef {
|
||||
name,
|
||||
abbreviation: abbrev,
|
||||
category: UnitCategory::Data,
|
||||
conversion: Conversion::Linear(factor),
|
||||
},
|
||||
aliases,
|
||||
);
|
||||
}
|
||||
|
||||
/// Register all data units and case-sensitive aliases.
|
||||
pub(crate) fn register_data(reg: &mut UnitRegistry) {
|
||||
// Base units
|
||||
linear(reg, "bit", "bit", 0.125, &["bits"]);
|
||||
linear(reg, "byte", "B", 1.0, &["bytes"]);
|
||||
|
||||
// Decimal byte units (powers of 1000)
|
||||
linear(reg, "kilobyte", "KB", 1000.0, &["kilobytes"]);
|
||||
linear(reg, "megabyte", "MB", 1e6, &["megabytes"]);
|
||||
linear(reg, "gigabyte", "GB", 1e9, &["gigabytes"]);
|
||||
linear(reg, "terabyte", "TB", 1e12, &["terabytes"]);
|
||||
linear(reg, "petabyte", "PB", 1e15, &["petabytes"]);
|
||||
linear(reg, "exabyte", "EB", 1e18, &["exabytes"]);
|
||||
|
||||
// Binary byte units (powers of 1024)
|
||||
linear(reg, "kibibyte", "KiB", 1024.0, &["kibibytes"]);
|
||||
linear(reg, "mebibyte", "MiB", 1_048_576.0, &["mebibytes"]);
|
||||
linear(reg, "gibibyte", "GiB", 1_073_741_824.0, &["gibibytes"]);
|
||||
linear(reg, "tebibyte", "TiB", 1_099_511_627_776.0, &["tebibytes"]);
|
||||
linear(reg, "pebibyte", "PiB", 1_125_899_906_842_624.0, &["pebibytes"]);
|
||||
|
||||
// Decimal bit units (powers of 1000, in bytes: factor / 8)
|
||||
linear(reg, "kilobit", "Kbit", 125.0, &["kilobits"]);
|
||||
linear(reg, "megabit", "Mbit", 125_000.0, &["megabits"]);
|
||||
linear(reg, "gigabit", "Gbit", 125_000_000.0, &["gigabits"]);
|
||||
linear(reg, "terabit", "Tbit", 125_000_000_000.0, &["terabits"]);
|
||||
|
||||
// Case-sensitive aliases: uppercase B = bytes, lowercase b = bits.
|
||||
// Without these, "MB" and "Mb" both lowercase to "mb" -> megabyte (wrong for megabit).
|
||||
reg.register_case_sensitive_alias("b", "bit");
|
||||
reg.register_case_sensitive_alias("B", "byte");
|
||||
|
||||
reg.register_case_sensitive_alias("Kb", "kilobit");
|
||||
reg.register_case_sensitive_alias("Mb", "megabit");
|
||||
reg.register_case_sensitive_alias("Gb", "gigabit");
|
||||
reg.register_case_sensitive_alias("Tb", "terabit");
|
||||
|
||||
reg.register_case_sensitive_alias("KB", "kilobyte");
|
||||
reg.register_case_sensitive_alias("MB", "megabyte");
|
||||
reg.register_case_sensitive_alias("GB", "gigabyte");
|
||||
reg.register_case_sensitive_alias("TB", "terabyte");
|
||||
reg.register_case_sensitive_alias("PB", "petabyte");
|
||||
reg.register_case_sensitive_alias("EB", "exabyte");
|
||||
|
||||
reg.register_case_sensitive_alias("KiB", "kibibyte");
|
||||
reg.register_case_sensitive_alias("MiB", "mebibyte");
|
||||
reg.register_case_sensitive_alias("GiB", "gibibyte");
|
||||
reg.register_case_sensitive_alias("TiB", "tebibyte");
|
||||
reg.register_case_sensitive_alias("PiB", "pebibyte");
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::convert;
|
||||
|
||||
// ─── Decimal units ──────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_decimal_mb_to_kb() {
|
||||
let result = convert(1.0, "MB", "KB").unwrap();
|
||||
assert!((result - 1000.0).abs() < 1e-10, "Expected 1000, got {}", result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decimal_gb_to_mb() {
|
||||
let result = convert(1.0, "GB", "MB").unwrap();
|
||||
assert!((result - 1000.0).abs() < 1e-10, "Expected 1000, got {}", result);
|
||||
}
|
||||
|
||||
// ─── Binary units ───────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_binary_mib_to_kib() {
|
||||
let result = convert(1.0, "MiB", "KiB").unwrap();
|
||||
assert!((result - 1024.0).abs() < 1e-10, "Expected 1024, got {}", result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_gib_to_mib() {
|
||||
let result = convert(1.0, "GiB", "MiB").unwrap();
|
||||
assert!((result - 1024.0).abs() < 1e-10, "Expected 1024, got {}", result);
|
||||
}
|
||||
|
||||
// ─── Bit/byte distinction ───────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_byte_to_bits() {
|
||||
let result = convert(1.0, "byte", "bits").unwrap();
|
||||
assert!((result - 8.0).abs() < 1e-10, "Expected 8, got {}", result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_megabytes_to_megabits_case_sensitive() {
|
||||
let result = convert(1.0, "MB", "Mb").unwrap();
|
||||
assert!((result - 8.0).abs() < 1e-10, "Expected 8, got {}", result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kilobytes_to_kilobits() {
|
||||
let result = convert(1.0, "KB", "Kb").unwrap();
|
||||
assert!((result - 8.0).abs() < 1e-10, "Expected 8, got {}", result);
|
||||
}
|
||||
|
||||
// ─── Cross-conversion (binary <-> decimal) ─────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_cross_convert_gib_to_gb() {
|
||||
let result = convert(5.0, "GiB", "GB").unwrap();
|
||||
let expected = 5.0 * 1_073_741_824.0 / 1e9;
|
||||
assert!((result - expected).abs() < 1e-6, "Expected {}, got {}", expected, result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_convert_tb_to_tib() {
|
||||
let result = convert(1.0, "TB", "TiB").unwrap();
|
||||
let expected = 1e12 / 1_099_511_627_776.0;
|
||||
assert!((result - expected).abs() < 1e-6, "Expected {}, got {}", expected, result);
|
||||
}
|
||||
|
||||
// ─── Case-sensitive lookups ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_case_sensitive_b_vs_big_b() {
|
||||
let reg = super::super::registry();
|
||||
let big_b = reg.lookup("B").unwrap();
|
||||
assert_eq!(big_b.name, "byte");
|
||||
let small_b = reg.lookup("b").unwrap();
|
||||
assert_eq!(small_b.name, "bit");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_case_sensitive_mb_vs_big_mb() {
|
||||
let reg = super::super::registry();
|
||||
let big_mb = reg.lookup("MB").unwrap();
|
||||
assert_eq!(big_mb.name, "megabyte");
|
||||
let small_mb = reg.lookup("Mb").unwrap();
|
||||
assert_eq!(small_mb.name, "megabit");
|
||||
}
|
||||
|
||||
// ─── Edge cases ─────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_zero_data_conversion() {
|
||||
let result = convert(0.0, "MB", "KB").unwrap();
|
||||
assert!(result.abs() < 1e-10, "Expected 0, got {}", result);
|
||||
}
|
||||
}
|
||||
606
calcpad-engine/src/units/mod.rs
Normal file
606
calcpad-engine/src/units/mod.rs
Normal file
@@ -0,0 +1,606 @@
|
||||
//! Unit conversion system for calcpad-engine.
|
||||
//!
|
||||
//! Provides a comprehensive unit registry with 200+ built-in units across 13 categories,
|
||||
//! SI prefix decomposition, CSS/screen unit support with configurable PPI, case-sensitive
|
||||
//! data unit handling (B vs b), and custom user-defined units.
|
||||
|
||||
pub mod categories;
|
||||
pub mod css;
|
||||
pub mod custom;
|
||||
pub mod data;
|
||||
pub mod registry;
|
||||
pub mod si_prefix;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
pub use categories::UnitCategory;
|
||||
pub use css::CssConfig;
|
||||
pub use custom::CustomUnitRegistry;
|
||||
|
||||
/// How to convert a unit to/from its category's base unit.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Conversion {
|
||||
/// Linear: value_in_base = value * factor
|
||||
/// e.g., 1 km = 1000 m, so factor = 1000.0
|
||||
Linear(f64),
|
||||
|
||||
/// Formula-based (non-linear): used for temperature etc.
|
||||
/// to_base(value) and from_base(value) are function pointers.
|
||||
Formula {
|
||||
to_base: fn(f64) -> f64,
|
||||
from_base: fn(f64) -> f64,
|
||||
},
|
||||
}
|
||||
|
||||
/// A unit definition in the registry.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UnitDef {
|
||||
/// Canonical name (e.g., "meter").
|
||||
pub name: &'static str,
|
||||
/// Short abbreviation for display (e.g., "m").
|
||||
pub abbreviation: &'static str,
|
||||
/// Category this unit belongs to.
|
||||
pub category: UnitCategory,
|
||||
/// How to convert to/from the base unit.
|
||||
pub conversion: Conversion,
|
||||
}
|
||||
|
||||
impl UnitDef {
|
||||
/// Convert a value in this unit to the base unit.
|
||||
pub fn to_base(&self, value: f64) -> f64 {
|
||||
match &self.conversion {
|
||||
Conversion::Linear(factor) => value * factor,
|
||||
Conversion::Formula { to_base, .. } => to_base(value),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a value from the base unit to this unit.
|
||||
pub fn from_base(&self, value: f64) -> f64 {
|
||||
match &self.conversion {
|
||||
Conversion::Linear(factor) => value / factor,
|
||||
Conversion::Formula { from_base, .. } => from_base(value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A resolved unit: a base unit definition combined with an SI prefix factor.
|
||||
#[derive(Debug)]
|
||||
pub struct ResolvedUnit<'a> {
|
||||
/// The base unit definition.
|
||||
pub unit: &'a UnitDef,
|
||||
/// The SI prefix multiplication factor (1.0 if no prefix).
|
||||
pub prefix_factor: f64,
|
||||
}
|
||||
|
||||
impl<'a> ResolvedUnit<'a> {
|
||||
/// Convert a value in this (possibly prefixed) unit to the category's base unit.
|
||||
pub fn to_base(&self, value: f64) -> f64 {
|
||||
self.unit.to_base(value * self.prefix_factor)
|
||||
}
|
||||
|
||||
/// Convert a value from the category's base unit to this (possibly prefixed) unit.
|
||||
pub fn from_base(&self, value: f64) -> f64 {
|
||||
self.unit.from_base(value) / self.prefix_factor
|
||||
}
|
||||
}
|
||||
|
||||
/// The unit registry -- maps names/aliases to unit definitions.
|
||||
pub struct UnitRegistry {
|
||||
/// Maps lowercase name/alias -> index into `units` vec.
|
||||
lookup: HashMap<String, usize>,
|
||||
/// Maps exact-case name/alias -> index for case-sensitive units (e.g., data: B vs b).
|
||||
case_sensitive_lookup: HashMap<String, usize>,
|
||||
/// All registered unit definitions.
|
||||
units: Vec<UnitDef>,
|
||||
}
|
||||
|
||||
impl UnitRegistry {
|
||||
fn new() -> Self {
|
||||
let mut reg = UnitRegistry {
|
||||
lookup: HashMap::new(),
|
||||
case_sensitive_lookup: HashMap::new(),
|
||||
units: Vec::new(),
|
||||
};
|
||||
registry::register_all(&mut reg);
|
||||
reg
|
||||
}
|
||||
|
||||
/// Register a unit with its aliases.
|
||||
pub(crate) fn register(&mut self, def: UnitDef, aliases: &[&str]) {
|
||||
let idx = self.units.len();
|
||||
// Register canonical name
|
||||
self.lookup.insert(def.name.to_lowercase(), idx);
|
||||
// Register abbreviation
|
||||
self.lookup.insert(def.abbreviation.to_lowercase(), idx);
|
||||
// Register additional aliases
|
||||
for alias in aliases {
|
||||
self.lookup.insert(alias.to_lowercase(), idx);
|
||||
}
|
||||
self.units.push(def);
|
||||
}
|
||||
|
||||
/// Register a case-sensitive alias for a unit already in the registry.
|
||||
/// Used for data units where case distinguishes bytes (B) from bits (b).
|
||||
pub(crate) fn register_case_sensitive_alias(&mut self, alias: &str, canonical_name: &str) {
|
||||
if let Some(&idx) = self.lookup.get(&canonical_name.to_lowercase()) {
|
||||
self.case_sensitive_lookup.insert(alias.to_string(), idx);
|
||||
}
|
||||
}
|
||||
|
||||
/// Look up a unit by any name or alias.
|
||||
/// Checks case-sensitive overrides first (for data units), then falls back to
|
||||
/// case-insensitive lookup.
|
||||
pub fn lookup(&self, name: &str) -> Option<&UnitDef> {
|
||||
// Case-sensitive lookup first (e.g., "Mb" -> megabit, "MB" -> megabyte)
|
||||
if let Some(&idx) = self.case_sensitive_lookup.get(name) {
|
||||
return Some(&self.units[idx]);
|
||||
}
|
||||
// Fall back to case-insensitive
|
||||
self.lookup
|
||||
.get(&name.to_lowercase())
|
||||
.map(|&idx| &self.units[idx])
|
||||
}
|
||||
|
||||
/// Resolve a unit name, trying direct lookup first, then SI prefix decomposition.
|
||||
///
|
||||
/// Returns a `ResolvedUnit` with the base unit definition and prefix factor.
|
||||
/// Pre-registered units always take priority over prefix decomposition.
|
||||
pub fn resolve_with_prefix(&self, name: &str) -> Option<ResolvedUnit<'_>> {
|
||||
// 1. Try direct lookup first (pre-registered units take priority)
|
||||
if let Some(unit) = self.lookup(name) {
|
||||
return Some(ResolvedUnit {
|
||||
unit,
|
||||
prefix_factor: 1.0,
|
||||
});
|
||||
}
|
||||
|
||||
// 2. Try short-form SI prefix decomposition (case-sensitive on original input)
|
||||
if let Some((prefix, base_abbrev)) = si_prefix::match_short_prefix(name) {
|
||||
if let Some(unit) = self.lookup(base_abbrev) {
|
||||
return Some(ResolvedUnit {
|
||||
unit,
|
||||
prefix_factor: prefix.factor,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Try long-form SI prefix decomposition
|
||||
if let Some((prefix, base_name)) = si_prefix::match_long_prefix(name) {
|
||||
if let Some(unit) = self.lookup(base_name) {
|
||||
return Some(ResolvedUnit {
|
||||
unit,
|
||||
prefix_factor: prefix.factor,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Get all registered unit definitions.
|
||||
pub fn all_units(&self) -> &[UnitDef] {
|
||||
&self.units
|
||||
}
|
||||
|
||||
/// Count total registered units.
|
||||
pub fn unit_count(&self) -> usize {
|
||||
self.units.len()
|
||||
}
|
||||
|
||||
/// Get all supported categories.
|
||||
pub fn categories(&self) -> Vec<UnitCategory> {
|
||||
let mut cats: Vec<UnitCategory> = Vec::new();
|
||||
for unit in &self.units {
|
||||
if !cats.contains(&unit.category) {
|
||||
cats.push(unit.category);
|
||||
}
|
||||
}
|
||||
cats
|
||||
}
|
||||
}
|
||||
|
||||
/// Global static registry instance.
|
||||
static REGISTRY: LazyLock<UnitRegistry> = LazyLock::new(UnitRegistry::new);
|
||||
|
||||
/// Get the global unit registry.
|
||||
pub fn registry() -> &'static UnitRegistry {
|
||||
®ISTRY
|
||||
}
|
||||
|
||||
/// Convert a value from one unit to another.
|
||||
/// Supports SI-prefixed units (e.g., "km", "MHz", "nanoseconds").
|
||||
/// Returns Err if units are incompatible or not found.
|
||||
pub fn convert(value: f64, from: &str, to: &str) -> Result<f64, String> {
|
||||
let reg = registry();
|
||||
|
||||
let from_resolved = reg
|
||||
.resolve_with_prefix(from)
|
||||
.ok_or_else(|| format!("Unknown unit: {}", from))?;
|
||||
let to_resolved = reg
|
||||
.resolve_with_prefix(to)
|
||||
.ok_or_else(|| format!("Unknown unit: {}", to))?;
|
||||
|
||||
if from_resolved.unit.category != to_resolved.unit.category {
|
||||
return Err(format!(
|
||||
"Cannot convert between {} ({}) and {} ({})",
|
||||
from_resolved.unit.name,
|
||||
from_resolved.unit.category,
|
||||
to_resolved.unit.name,
|
||||
to_resolved.unit.category,
|
||||
));
|
||||
}
|
||||
|
||||
// Convert: from_unit (with prefix) -> base -> to_unit (with prefix)
|
||||
let base_value = from_resolved.to_base(value);
|
||||
let result = to_resolved.from_base(base_value);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Convert a value from one unit to another, with CSS config for screen units.
|
||||
/// Uses CSS-specific conversion for CSS units, standard conversion otherwise.
|
||||
pub fn convert_with_config(
|
||||
value: f64,
|
||||
from: &str,
|
||||
to: &str,
|
||||
css_config: &CssConfig,
|
||||
) -> Result<f64, String> {
|
||||
let reg = registry();
|
||||
|
||||
let from_resolved = reg
|
||||
.resolve_with_prefix(from)
|
||||
.ok_or_else(|| format!("Unknown unit: {}", from))?;
|
||||
let to_resolved = reg
|
||||
.resolve_with_prefix(to)
|
||||
.ok_or_else(|| format!("Unknown unit: {}", to))?;
|
||||
|
||||
if from_resolved.unit.category != to_resolved.unit.category {
|
||||
return Err(format!(
|
||||
"Cannot convert between {} ({}) and {} ({})",
|
||||
from_resolved.unit.name,
|
||||
from_resolved.unit.category,
|
||||
to_resolved.unit.name,
|
||||
to_resolved.unit.category,
|
||||
));
|
||||
}
|
||||
|
||||
// For CSS units, use config-aware conversion
|
||||
if from_resolved.unit.category == UnitCategory::CssScreen {
|
||||
return css::convert_css(value, from, to, css_config);
|
||||
}
|
||||
|
||||
// Standard conversion through base unit
|
||||
let base_value = from_resolved.to_base(value);
|
||||
Ok(to_resolved.from_base(base_value))
|
||||
}
|
||||
|
||||
/// Check if a unit name belongs to the CSS/screen category.
|
||||
pub fn is_css_unit(name: &str) -> bool {
|
||||
let reg = registry();
|
||||
if let Some(resolved) = reg.resolve_with_prefix(name) {
|
||||
resolved.unit.category == UnitCategory::CssScreen
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ─── Registry fundamentals ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_registry_has_13_categories() {
|
||||
let reg = registry();
|
||||
let cats = reg.categories();
|
||||
assert!(
|
||||
cats.len() >= 13,
|
||||
"Expected at least 13 categories, got {}",
|
||||
cats.len()
|
||||
);
|
||||
for cat in UnitCategory::all() {
|
||||
assert!(cats.contains(cat), "Missing category: {}", cat);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_registry_has_200_plus_units() {
|
||||
let reg = registry();
|
||||
assert!(
|
||||
reg.unit_count() >= 200,
|
||||
"Expected at least 200 units, got {}",
|
||||
reg.unit_count()
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Basic conversions ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_linear_conversion_miles_to_km() {
|
||||
let result = convert(5.0, "miles", "km").unwrap();
|
||||
assert!(
|
||||
(result - 8.04672).abs() < 1e-4,
|
||||
"Expected 8.04672, got {}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_km_to_miles() {
|
||||
let result = convert(5.0, "km", "miles").unwrap();
|
||||
assert!(
|
||||
(result - 3.10686).abs() < 1e-4,
|
||||
"Expected ~3.10686, got {}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temperature_fahrenheit_to_celsius() {
|
||||
let result = convert(100.0, "F", "C").unwrap();
|
||||
let expected = (100.0 - 32.0) * 5.0 / 9.0;
|
||||
assert!(
|
||||
(result - expected).abs() < 1e-10,
|
||||
"Expected {}, got {}",
|
||||
expected,
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temperature_celsius_to_fahrenheit() {
|
||||
let result = convert(0.0, "C", "F").unwrap();
|
||||
assert!((result - 32.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temperature_kelvin_roundtrip() {
|
||||
let result = convert(100.0, "C", "K").unwrap();
|
||||
assert!((result - 373.15).abs() < 1e-10);
|
||||
let back = convert(result, "K", "C").unwrap();
|
||||
assert!((back - 100.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_volume_gallon_to_liters() {
|
||||
let result = convert(1.0, "gallon", "liters").unwrap();
|
||||
assert!(
|
||||
(result - 3.78541).abs() < 1e-4,
|
||||
"Expected 3.78541, got {}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_incompatible_categories() {
|
||||
let result = convert(5.0, "kg", "meters");
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.contains("mass"), "Error should mention mass: {}", err);
|
||||
assert!(err.contains("length"), "Error should mention length: {}", err);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_aliases_resolve_same() {
|
||||
let reg = registry();
|
||||
let m1 = reg.lookup("meter").unwrap();
|
||||
let m2 = reg.lookup("metre").unwrap();
|
||||
let m3 = reg.lookup("m").unwrap();
|
||||
assert_eq!(m1.name, m2.name);
|
||||
assert_eq!(m2.name, m3.name);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_case_insensitive_lookup() {
|
||||
let reg = registry();
|
||||
assert!(reg.lookup("KM").is_some());
|
||||
assert!(reg.lookup("Km").is_some());
|
||||
assert!(reg.lookup("km").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_unit() {
|
||||
let result = convert(1.0, "frobnitz", "meters");
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("Unknown unit"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_identity_conversion() {
|
||||
let result = convert(42.0, "meters", "meters").unwrap();
|
||||
assert!((result - 42.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
// ─── SI prefix integration ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_all_prefix_factors_via_conversion() {
|
||||
let r = convert(1.0, "nm", "m").unwrap();
|
||||
assert!((r - 1e-9).abs() < 1e-20, "nano: got {}", r);
|
||||
|
||||
let r = convert(1.0, "ms", "s").unwrap();
|
||||
assert!((r - 1e-3).abs() < 1e-14, "milli: got {}", r);
|
||||
|
||||
let r = convert(1.0, "cm", "m").unwrap();
|
||||
assert!((r - 1e-2).abs() < 1e-14, "centi: got {}", r);
|
||||
|
||||
let r = convert(1.0, "km", "m").unwrap();
|
||||
assert!((r - 1e3).abs() < 1e-8, "kilo: got {}", r);
|
||||
|
||||
let r = convert(1.0, "MB", "B").unwrap();
|
||||
assert!((r - 1e6).abs() < 1e-4, "mega: got {}", r);
|
||||
|
||||
let r = convert(1.0, "GB", "B").unwrap();
|
||||
assert!((r - 1e9).abs() < 1e-1, "giga: got {}", r);
|
||||
|
||||
let r = convert(1.0, "TB", "B").unwrap();
|
||||
assert!((r - 1e12).abs() < 1e2, "tera: got {}", r);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_direct_lookup_priority() {
|
||||
let reg = registry();
|
||||
let resolved = reg.resolve_with_prefix("km").unwrap();
|
||||
assert_eq!(resolved.prefix_factor, 1.0);
|
||||
assert_eq!(resolved.unit.name, "kilometer");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_novel_prefix_combination() {
|
||||
let reg = registry();
|
||||
let resolved = reg.resolve_with_prefix("nJ");
|
||||
assert!(resolved.is_some(), "nJ should resolve via SI prefix");
|
||||
let r = resolved.unwrap();
|
||||
assert_eq!(r.unit.name, "joule");
|
||||
assert_eq!(r.prefix_factor, 1e-9);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_long_form_novel() {
|
||||
let reg = registry();
|
||||
let resolved = reg.resolve_with_prefix("terawatts");
|
||||
assert!(resolved.is_some(), "terawatts should resolve via long-form SI prefix");
|
||||
let r = resolved.unwrap();
|
||||
assert_eq!(r.unit.name, "watt");
|
||||
assert_eq!(r.prefix_factor, 1e12);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kilofahrenheit_rejected() {
|
||||
let result = convert(1.0, "kilofahrenheit", "celsius");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kilomiles_rejected() {
|
||||
let result = convert(1.0, "kilomiles", "meters");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_micro_sign_u00b5_in_convert() {
|
||||
let result = convert(3.0, "\u{00B5}s", "ms").unwrap();
|
||||
assert!(
|
||||
(result - 0.003).abs() < 1e-10,
|
||||
"Expected 0.003, got {}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Data unit case sensitivity ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_case_sensitive_b_vs_big_b() {
|
||||
let reg = registry();
|
||||
let big_b = reg.lookup("B").unwrap();
|
||||
assert_eq!(big_b.name, "byte");
|
||||
let small_b = reg.lookup("b").unwrap();
|
||||
assert_eq!(small_b.name, "bit");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_case_sensitive_mb_vs_big_mb() {
|
||||
let reg = registry();
|
||||
let big_mb = reg.lookup("MB").unwrap();
|
||||
assert_eq!(big_mb.name, "megabyte");
|
||||
let small_mb = reg.lookup("Mb").unwrap();
|
||||
assert_eq!(small_mb.name, "megabit");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_megabytes_to_megabits() {
|
||||
let result = convert(1.0, "MB", "Mb").unwrap();
|
||||
assert!((result - 8.0).abs() < 1e-10, "Expected 8, got {}", result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_mib_to_kib() {
|
||||
let result = convert(1.0, "MiB", "KiB").unwrap();
|
||||
assert!((result - 1024.0).abs() < 1e-10, "Expected 1024, got {}", result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_convert_gib_to_gb() {
|
||||
let result = convert(5.0, "GiB", "GB").unwrap();
|
||||
let expected = 5.0 * 1_073_741_824.0 / 1e9;
|
||||
assert!((result - expected).abs() < 1e-6, "Expected {}, got {}", expected, result);
|
||||
}
|
||||
|
||||
// ─── CSS unit tests ─────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_css_12pt_to_px_default() {
|
||||
let config = CssConfig::default();
|
||||
let result = css::convert_css(12.0, "pt", "px", &config).unwrap();
|
||||
assert!((result - 16.0).abs() < 1e-10, "12pt should be 16px at PPI=96, got {}", result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_css_2em_to_px_default() {
|
||||
let config = CssConfig::default();
|
||||
let result = css::convert_css(2.0, "em", "px", &config).unwrap();
|
||||
assert!((result - 32.0).abs() < 1e-10, "2em should be 32px, got {}", result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_css_12pt_to_px_retina() {
|
||||
let config = CssConfig { ppi: 326.0, em_base_px: 16.0 };
|
||||
let result = css::convert_css(12.0, "pt", "px", &config).unwrap();
|
||||
let expected = 12.0 * 326.0 / 72.0;
|
||||
assert!((result - expected).abs() < 1e-10, "Expected {}, got {}", expected, result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_css_2em_custom_base() {
|
||||
let config = CssConfig { ppi: 96.0, em_base_px: 20.0 };
|
||||
let result = css::convert_css(2.0, "em", "px", &config).unwrap();
|
||||
assert!((result - 40.0).abs() < 1e-10, "2em at em=20 should be 40px, got {}", result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_css_units_in_css_category() {
|
||||
let reg = registry();
|
||||
for name in &["px", "pt", "em", "rem"] {
|
||||
let unit = reg.lookup(name).unwrap();
|
||||
assert_eq!(unit.category, UnitCategory::CssScreen, "{} should be CssScreen", name);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_css_incompatible_with_length() {
|
||||
let config = CssConfig::default();
|
||||
let result = convert_with_config(1.0, "px", "meters", &config);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_css_unit() {
|
||||
assert!(is_css_unit("px"));
|
||||
assert!(is_css_unit("pt"));
|
||||
assert!(is_css_unit("em"));
|
||||
assert!(is_css_unit("rem"));
|
||||
assert!(!is_css_unit("kg"));
|
||||
assert!(!is_css_unit("meters"));
|
||||
}
|
||||
|
||||
// ─── Performance ────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_o1_lookup_performance() {
|
||||
let reg = registry();
|
||||
let _ = reg.lookup("meter");
|
||||
let start = std::time::Instant::now();
|
||||
for _ in 0..10_000 {
|
||||
let _ = reg.lookup("kilometer");
|
||||
let _ = reg.lookup("lb");
|
||||
let _ = reg.lookup("F");
|
||||
}
|
||||
let elapsed = start.elapsed();
|
||||
assert!(
|
||||
elapsed.as_millis() < 50,
|
||||
"Lookup too slow: {}ms for 30k lookups",
|
||||
elapsed.as_millis()
|
||||
);
|
||||
}
|
||||
}
|
||||
422
calcpad-engine/src/units/registry.rs
Normal file
422
calcpad-engine/src/units/registry.rs
Normal file
@@ -0,0 +1,422 @@
|
||||
//! Unit registry population -- registers all built-in units across all categories.
|
||||
|
||||
use super::categories::UnitCategory;
|
||||
use super::{Conversion, UnitDef, UnitRegistry};
|
||||
|
||||
/// Register all units across all categories.
|
||||
pub(crate) fn register_all(reg: &mut UnitRegistry) {
|
||||
register_length(reg);
|
||||
register_mass(reg);
|
||||
register_volume(reg);
|
||||
register_area(reg);
|
||||
register_speed(reg);
|
||||
register_temperature(reg);
|
||||
super::data::register_data(reg);
|
||||
register_angle(reg);
|
||||
register_time(reg);
|
||||
register_pressure(reg);
|
||||
register_energy(reg);
|
||||
register_power(reg);
|
||||
register_force(reg);
|
||||
super::css::register_css_screen(reg);
|
||||
}
|
||||
|
||||
/// Helper to register a linear unit.
|
||||
fn linear(
|
||||
reg: &mut UnitRegistry,
|
||||
name: &'static str,
|
||||
abbrev: &'static str,
|
||||
category: UnitCategory,
|
||||
factor: f64,
|
||||
aliases: &[&str],
|
||||
) {
|
||||
reg.register(
|
||||
UnitDef {
|
||||
name,
|
||||
abbreviation: abbrev,
|
||||
category,
|
||||
conversion: Conversion::Linear(factor),
|
||||
},
|
||||
aliases,
|
||||
);
|
||||
}
|
||||
|
||||
// ─── LENGTH (base: meter) ────────────────────────────────────────────
|
||||
|
||||
fn register_length(reg: &mut UnitRegistry) {
|
||||
let c = UnitCategory::Length;
|
||||
|
||||
linear(reg, "meter", "m", c, 1.0, &["meters", "metre", "metres"]);
|
||||
linear(reg, "kilometer", "km", c, 1000.0, &["kilometers", "kilometre", "kilometres"]);
|
||||
linear(reg, "centimeter", "cm", c, 0.01, &["centimeters", "centimetre", "centimetres"]);
|
||||
linear(reg, "millimeter", "mm", c, 0.001, &["millimeters", "millimetre", "millimetres"]);
|
||||
linear(reg, "micrometer", "\u{03BC}m", c, 1e-6, &["micrometers", "micrometre", "micrometres", "micron", "microns"]);
|
||||
linear(reg, "nanometer", "nm", c, 1e-9, &["nanometers", "nanometre", "nanometres"]);
|
||||
linear(reg, "picometer", "pm", c, 1e-12, &["picometers", "picometre", "picometres"]);
|
||||
linear(reg, "decimeter", "dm", c, 0.1, &["decimeters", "decimetre", "decimetres"]);
|
||||
linear(reg, "hectometer", "hm", c, 100.0, &["hectometers", "hectometre", "hectometres"]);
|
||||
linear(reg, "mile", "mi", c, 1609.344, &["miles"]);
|
||||
linear(reg, "yard", "yd", c, 0.9144, &["yards"]);
|
||||
linear(reg, "foot", "ft", c, 0.3048, &["feet"]);
|
||||
linear(reg, "inch", "in", c, 0.0254, &["inches"]);
|
||||
linear(reg, "nautical mile", "nmi", c, 1852.0, &["nautical miles", "nauticalmile", "nauticalmiles"]);
|
||||
linear(reg, "fathom", "ftm", c, 1.8288, &["fathoms"]);
|
||||
linear(reg, "furlong", "fur", c, 201.168, &["furlongs"]);
|
||||
linear(reg, "chain", "ch", c, 20.1168, &["chains"]);
|
||||
linear(reg, "rod", "rd", c, 5.0292, &["rods", "perch", "pole"]);
|
||||
linear(reg, "league", "lea", c, 4828.032, &["leagues"]);
|
||||
linear(reg, "thou", "th", c, 0.0000254, &["mil", "mils"]);
|
||||
linear(reg, "angstrom", "\u{00C5}", c, 1e-10, &["angstroms"]);
|
||||
linear(reg, "light-year", "ly", c, 9.461e15, &["lightyear", "lightyears", "light-years"]);
|
||||
linear(reg, "astronomical unit", "au", c, 1.496e11, &["astronomical units", "astronomicalunit"]);
|
||||
linear(reg, "parsec", "pc", c, 3.086e16, &["parsecs"]);
|
||||
}
|
||||
|
||||
// ─── MASS (base: kilogram) ───────────────────────────────────────────
|
||||
|
||||
fn register_mass(reg: &mut UnitRegistry) {
|
||||
let c = UnitCategory::Mass;
|
||||
|
||||
linear(reg, "kilogram", "kg", c, 1.0, &["kilograms", "kilo", "kilos"]);
|
||||
linear(reg, "gram", "g", c, 0.001, &["grams", "gm"]);
|
||||
linear(reg, "milligram", "mg", c, 1e-6, &["milligrams"]);
|
||||
linear(reg, "microgram", "\u{03BC}g", c, 1e-9, &["micrograms", "mcg"]);
|
||||
linear(reg, "metric ton", "t", c, 1000.0, &["tonne", "tonnes", "metric tons"]);
|
||||
linear(reg, "pound", "lb", c, 0.45359237, &["pounds", "lbs"]);
|
||||
linear(reg, "ounce", "oz", c, 0.028349523125, &["ounces"]);
|
||||
linear(reg, "stone", "st", c, 6.35029318, &["stones"]);
|
||||
linear(reg, "short ton", "ton", c, 907.18474, &["tons", "short tons", "us ton"]);
|
||||
linear(reg, "long ton", "long ton", c, 1016.0469088, &["long tons", "imperial ton"]);
|
||||
linear(reg, "carat", "ct", c, 0.0002, &["carats"]);
|
||||
linear(reg, "grain", "gr", c, 0.00006479891, &["grains"]);
|
||||
linear(reg, "dram", "dr", c, 0.001771845195, &["drams"]);
|
||||
linear(reg, "hundredweight", "cwt", c, 45.359237, &["hundredweights"]);
|
||||
linear(reg, "slug", "slug", c, 14.593903, &["slugs"]);
|
||||
linear(reg, "atomic mass unit", "amu", c, 1.66053906660e-27, &["dalton", "daltons", "u"]);
|
||||
linear(reg, "decigram", "dg", c, 0.0001, &["decigrams"]);
|
||||
linear(reg, "centigram", "cg", c, 0.00001, &["centigrams"]);
|
||||
linear(reg, "quintal", "q", c, 100.0, &["quintals"]);
|
||||
linear(reg, "pennyweight", "dwt", c, 0.00155517384, &["pennyweights"]);
|
||||
linear(reg, "troy ounce", "oz t", c, 0.0311034768, &["troy ounces"]);
|
||||
linear(reg, "troy pound", "lb t", c, 0.3732417216, &["troy pounds"]);
|
||||
}
|
||||
|
||||
// ─── VOLUME (base: liter) ────────────────────────────────────────────
|
||||
|
||||
fn register_volume(reg: &mut UnitRegistry) {
|
||||
let c = UnitCategory::Volume;
|
||||
|
||||
linear(reg, "liter", "L", c, 1.0, &["liters", "litre", "litres", "l"]);
|
||||
linear(reg, "milliliter", "mL", c, 0.001, &["milliliters", "millilitre", "millilitres", "ml"]);
|
||||
linear(reg, "centiliter", "cL", c, 0.01, &["centiliters", "centilitre", "centilitres", "cl"]);
|
||||
linear(reg, "deciliter", "dL", c, 0.1, &["deciliters", "decilitre", "decilitres", "dl"]);
|
||||
linear(reg, "hectoliter", "hL", c, 100.0, &["hectoliters", "hectolitre", "hectolitres", "hl"]);
|
||||
linear(reg, "kiloliter", "kL", c, 1000.0, &["kiloliters", "kilolitre", "kilolitres", "kl"]);
|
||||
linear(reg, "cubic meter", "m\u{00B3}", c, 1000.0, &["cubic meters", "m3", "cbm"]);
|
||||
linear(reg, "cubic centimeter", "cm\u{00B3}", c, 0.001, &["cubic centimeters", "cm3", "cc"]);
|
||||
linear(reg, "cubic millimeter", "mm\u{00B3}", c, 1e-6, &["cubic millimeters", "mm3"]);
|
||||
linear(reg, "cubic inch", "in\u{00B3}", c, 0.016387064, &["cubic inches", "in3"]);
|
||||
linear(reg, "cubic foot", "ft\u{00B3}", c, 28.316846592, &["cubic feet", "ft3"]);
|
||||
linear(reg, "cubic yard", "yd\u{00B3}", c, 764.554857984, &["cubic yards", "yd3"]);
|
||||
linear(reg, "gallon", "gal", c, 3.785411784, &["gallons", "us gallon", "us gallons"]);
|
||||
linear(reg, "quart", "qt", c, 0.946352946, &["quarts"]);
|
||||
linear(reg, "pint", "pt", c, 0.473176473, &["pints"]);
|
||||
linear(reg, "cup", "cup", c, 0.2365882365, &["cups"]);
|
||||
linear(reg, "fluid ounce", "fl oz", c, 0.0295735295625, &["fluid ounces", "floz"]);
|
||||
linear(reg, "tablespoon", "tbsp", c, 0.01478676478125, &["tablespoons", "tbs"]);
|
||||
linear(reg, "teaspoon", "tsp", c, 0.00492892159375, &["teaspoons"]);
|
||||
linear(reg, "imperial gallon", "imp gal", c, 4.54609, &["imperial gallons", "uk gallon", "uk gallons"]);
|
||||
linear(reg, "imperial quart", "imp qt", c, 1.1365225, &["imperial quarts"]);
|
||||
linear(reg, "imperial pint", "imp pt", c, 0.56826125, &["imperial pints"]);
|
||||
linear(reg, "imperial fluid ounce", "imp fl oz", c, 0.0284130625, &["imperial fluid ounces"]);
|
||||
linear(reg, "barrel", "bbl", c, 158.987294928, &["barrels", "oil barrel"]);
|
||||
linear(reg, "bushel", "bu", c, 35.23907016688, &["bushels"]);
|
||||
linear(reg, "gill", "gi", c, 0.1182941183, &["gills"]);
|
||||
linear(reg, "minim", "minim", c, 0.00006161152, &["minims"]);
|
||||
linear(reg, "dram (fluid)", "fl dr", c, 0.003696691, &["fluid drams", "fluid dram"]);
|
||||
linear(reg, "hogshead", "hhd", c, 238.480942392, &["hogsheads"]);
|
||||
}
|
||||
|
||||
// ─── AREA (base: square meter) ───────────────────────────────────────
|
||||
|
||||
fn register_area(reg: &mut UnitRegistry) {
|
||||
let c = UnitCategory::Area;
|
||||
|
||||
linear(reg, "square meter", "m\u{00B2}", c, 1.0, &["square meters", "sq m", "sqm", "m2"]);
|
||||
linear(reg, "square kilometer", "km\u{00B2}", c, 1e6, &["square kilometers", "sq km", "sqkm", "km2"]);
|
||||
linear(reg, "square centimeter", "cm\u{00B2}", c, 1e-4, &["square centimeters", "sq cm", "sqcm", "cm2"]);
|
||||
linear(reg, "square millimeter", "mm\u{00B2}", c, 1e-6, &["square millimeters", "sq mm", "sqmm", "mm2"]);
|
||||
linear(reg, "hectare", "ha", c, 10000.0, &["hectares"]);
|
||||
linear(reg, "acre", "ac", c, 4046.8564224, &["acres"]);
|
||||
linear(reg, "square mile", "mi\u{00B2}", c, 2_589_988.110336, &["square miles", "sq mi", "sqmi", "mi2"]);
|
||||
linear(reg, "square yard", "yd\u{00B2}", c, 0.83612736, &["square yards", "sq yd", "sqyd", "yd2"]);
|
||||
linear(reg, "square foot", "ft\u{00B2}", c, 0.09290304, &["square feet", "sq ft", "sqft", "ft2"]);
|
||||
linear(reg, "square inch", "in\u{00B2}", c, 0.00064516, &["square inches", "sq in", "sqin", "in2"]);
|
||||
linear(reg, "are", "a", c, 100.0, &["ares"]);
|
||||
linear(reg, "barn", "b", c, 1e-28, &["barns"]);
|
||||
linear(reg, "dunam", "dunam", c, 1000.0, &["dunams", "dunum"]);
|
||||
linear(reg, "township", "twp", c, 93_239_571.972, &["townships"]);
|
||||
linear(reg, "rood", "rood", c, 1011.7141056, &["roods"]);
|
||||
}
|
||||
|
||||
// ─── SPEED (base: meter per second) ──────────────────────────────────
|
||||
|
||||
fn register_speed(reg: &mut UnitRegistry) {
|
||||
let c = UnitCategory::Speed;
|
||||
|
||||
linear(reg, "meter per second", "m/s", c, 1.0, &["meters per second", "mps"]);
|
||||
linear(reg, "kilometer per hour", "km/h", c, 1.0 / 3.6, &["kilometers per hour", "kph", "kmh", "kmph"]);
|
||||
linear(reg, "mile per hour", "mph", c, 0.44704, &["miles per hour"]);
|
||||
linear(reg, "knot", "kn", c, 0.514444, &["knots", "kt"]);
|
||||
linear(reg, "foot per second", "ft/s", c, 0.3048, &["feet per second", "fps"]);
|
||||
linear(reg, "centimeter per second", "cm/s", c, 0.01, &["centimeters per second"]);
|
||||
linear(reg, "mach", "Ma", c, 340.29, &["machs"]);
|
||||
linear(reg, "speed of light", "c", c, 299_792_458.0, &[]);
|
||||
linear(reg, "inch per second", "in/s", c, 0.0254, &["inches per second"]);
|
||||
linear(reg, "yard per second", "yd/s", c, 0.9144, &["yards per second"]);
|
||||
linear(reg, "mile per second", "mi/s", c, 1609.344, &["miles per second"]);
|
||||
}
|
||||
|
||||
// ─── TEMPERATURE (base: kelvin) ──────────────────────────────────────
|
||||
|
||||
fn register_temperature(reg: &mut UnitRegistry) {
|
||||
let c = UnitCategory::Temperature;
|
||||
|
||||
reg.register(
|
||||
UnitDef {
|
||||
name: "kelvin",
|
||||
abbreviation: "K",
|
||||
category: c,
|
||||
conversion: Conversion::Linear(1.0),
|
||||
},
|
||||
&["kelvins"],
|
||||
);
|
||||
|
||||
reg.register(
|
||||
UnitDef {
|
||||
name: "celsius",
|
||||
abbreviation: "\u{00B0}C",
|
||||
category: c,
|
||||
conversion: Conversion::Formula {
|
||||
to_base: |v| v + 273.15,
|
||||
from_base: |v| v - 273.15,
|
||||
},
|
||||
},
|
||||
&["degc", "degC", "\u{00B0}c", "C"],
|
||||
);
|
||||
|
||||
reg.register(
|
||||
UnitDef {
|
||||
name: "fahrenheit",
|
||||
abbreviation: "\u{00B0}F",
|
||||
category: c,
|
||||
conversion: Conversion::Formula {
|
||||
to_base: |v| (v - 32.0) * 5.0 / 9.0 + 273.15,
|
||||
from_base: |v| (v - 273.15) * 9.0 / 5.0 + 32.0,
|
||||
},
|
||||
},
|
||||
&["degf", "degF", "\u{00B0}f", "F"],
|
||||
);
|
||||
|
||||
reg.register(
|
||||
UnitDef {
|
||||
name: "rankine",
|
||||
abbreviation: "\u{00B0}R",
|
||||
category: c,
|
||||
conversion: Conversion::Formula {
|
||||
to_base: |v| v / 1.8,
|
||||
from_base: |v| v * 1.8,
|
||||
},
|
||||
},
|
||||
&["degr", "degR", "\u{00B0}r", "R"],
|
||||
);
|
||||
}
|
||||
|
||||
// ─── ANGLE (base: radian) ───────────────────────────────────────────
|
||||
|
||||
fn register_angle(reg: &mut UnitRegistry) {
|
||||
let c = UnitCategory::Angle;
|
||||
|
||||
linear(reg, "radian", "rad", c, 1.0, &["radians"]);
|
||||
linear(reg, "degree", "deg", c, std::f64::consts::PI / 180.0, &["degrees", "\u{00B0}"]);
|
||||
linear(reg, "gradian", "gon", c, std::f64::consts::PI / 200.0, &["gradians", "grad", "grads"]);
|
||||
linear(reg, "arcminute", "arcmin", c, std::f64::consts::PI / 10800.0, &["arcminutes", "arc minute", "arc minutes", "MOA"]);
|
||||
linear(reg, "arcsecond", "arcsec", c, std::f64::consts::PI / 648000.0, &["arcseconds", "arc second", "arc seconds"]);
|
||||
linear(reg, "revolution", "rev", c, 2.0 * std::f64::consts::PI, &["revolutions", "turn", "turns"]);
|
||||
linear(reg, "milliradian", "mrad", c, 0.001, &["milliradians"]);
|
||||
}
|
||||
|
||||
// ─── TIME (base: second) ────────────────────────────────────────────
|
||||
|
||||
fn register_time(reg: &mut UnitRegistry) {
|
||||
let c = UnitCategory::Time;
|
||||
|
||||
linear(reg, "second", "s", c, 1.0, &["seconds", "sec", "secs"]);
|
||||
linear(reg, "millisecond", "ms", c, 0.001, &["milliseconds"]);
|
||||
linear(reg, "microsecond", "\u{03BC}s", c, 1e-6, &["microseconds", "us"]);
|
||||
linear(reg, "nanosecond", "ns", c, 1e-9, &["nanoseconds"]);
|
||||
linear(reg, "minute", "min", c, 60.0, &["minutes", "mins"]);
|
||||
linear(reg, "hour", "hr", c, 3600.0, &["hours", "hrs", "h"]);
|
||||
linear(reg, "day", "d", c, 86400.0, &["days"]);
|
||||
linear(reg, "week", "wk", c, 604800.0, &["weeks", "wks"]);
|
||||
linear(reg, "fortnight", "fn", c, 1_209_600.0, &["fortnights"]);
|
||||
linear(reg, "month", "mo", c, 2_629_746.0, &["months"]); // average month
|
||||
linear(reg, "year", "yr", c, 31_556_952.0, &["years", "yrs"]); // average year
|
||||
linear(reg, "decade", "dec", c, 315_569_520.0, &["decades"]);
|
||||
linear(reg, "century", "cent", c, 3_155_695_200.0, &["centuries"]);
|
||||
linear(reg, "millennium", "mill", c, 31_556_952_000.0, &["millennia", "millenniums"]);
|
||||
linear(reg, "picosecond", "ps", c, 1e-12, &["picoseconds"]);
|
||||
linear(reg, "shake", "shake", c, 1e-8, &["shakes"]);
|
||||
linear(reg, "sidereal day", "sid day", c, 86164.0905, &["sidereal days"]);
|
||||
}
|
||||
|
||||
// ─── PRESSURE (base: pascal) ─────────────────────────────────────────
|
||||
|
||||
fn register_pressure(reg: &mut UnitRegistry) {
|
||||
let c = UnitCategory::Pressure;
|
||||
|
||||
linear(reg, "pascal", "Pa", c, 1.0, &["pascals"]);
|
||||
linear(reg, "kilopascal", "kPa", c, 1000.0, &["kilopascals"]);
|
||||
linear(reg, "megapascal", "MPa", c, 1e6, &["megapascals"]);
|
||||
linear(reg, "gigapascal", "GPa", c, 1e9, &["gigapascals"]);
|
||||
linear(reg, "bar", "bar", c, 100_000.0, &["bars"]);
|
||||
linear(reg, "millibar", "mbar", c, 100.0, &["millibars"]);
|
||||
linear(reg, "atmosphere", "atm", c, 101_325.0, &["atmospheres"]);
|
||||
linear(reg, "pound per square inch", "psi", c, 6894.757293168, &["pounds per square inch"]);
|
||||
linear(reg, "torr", "Torr", c, 133.322368421, &["torrs"]);
|
||||
linear(reg, "millimeter of mercury", "mmHg", c, 133.322387415, &["millimeters of mercury", "mm Hg"]);
|
||||
linear(reg, "inch of mercury", "inHg", c, 3386.389, &["inches of mercury", "in Hg"]);
|
||||
linear(reg, "inch of water", "inH2O", c, 249.08891, &["inches of water", "in H2O"]);
|
||||
}
|
||||
|
||||
// ─── ENERGY (base: joule) ───────────────────────────────────────────
|
||||
|
||||
fn register_energy(reg: &mut UnitRegistry) {
|
||||
let c = UnitCategory::Energy;
|
||||
|
||||
linear(reg, "joule", "J", c, 1.0, &["joules"]);
|
||||
linear(reg, "kilojoule", "kJ", c, 1000.0, &["kilojoules"]);
|
||||
linear(reg, "megajoule", "MJ", c, 1e6, &["megajoules"]);
|
||||
linear(reg, "gigajoule", "GJ", c, 1e9, &["gigajoules"]);
|
||||
linear(reg, "calorie", "cal", c, 4.184, &["calories"]);
|
||||
linear(reg, "kilocalorie", "kcal", c, 4184.0, &["kilocalories", "Cal", "food calorie", "food calories"]);
|
||||
linear(reg, "watt-hour", "Wh", c, 3600.0, &["watt-hours", "watthour", "watthours"]);
|
||||
linear(reg, "kilowatt-hour", "kWh", c, 3_600_000.0, &["kilowatt-hours", "kilowatthour"]);
|
||||
linear(reg, "megawatt-hour", "MWh", c, 3.6e9, &["megawatt-hours"]);
|
||||
linear(reg, "british thermal unit", "BTU", c, 1055.05585262, &["btu", "btus", "british thermal units"]);
|
||||
linear(reg, "therm", "thm", c, 105_505_585.262, &["therms"]);
|
||||
linear(reg, "electronvolt", "eV", c, 1.602176634e-19, &["electronvolts", "electron volt"]);
|
||||
linear(reg, "kiloelectronvolt", "keV", c, 1.602176634e-16, &["kiloelectronvolts"]);
|
||||
linear(reg, "megaelectronvolt", "MeV", c, 1.602176634e-13, &["megaelectronvolts"]);
|
||||
linear(reg, "erg", "erg", c, 1e-7, &["ergs"]);
|
||||
linear(reg, "foot-pound", "ft\u{00B7}lbf", c, 1.3558179483, &["foot-pounds", "ft-lbf", "ftlbf", "foot pound"]);
|
||||
}
|
||||
|
||||
// ─── POWER (base: watt) ─────────────────────────────────────────────
|
||||
|
||||
fn register_power(reg: &mut UnitRegistry) {
|
||||
let c = UnitCategory::Power;
|
||||
|
||||
linear(reg, "watt", "W", c, 1.0, &["watts"]);
|
||||
linear(reg, "milliwatt", "mW", c, 0.001, &["milliwatts"]);
|
||||
linear(reg, "kilowatt", "kW", c, 1000.0, &["kilowatts"]);
|
||||
linear(reg, "megawatt", "MW", c, 1e6, &["megawatts"]);
|
||||
linear(reg, "gigawatt", "GW", c, 1e9, &["gigawatts"]);
|
||||
linear(reg, "horsepower", "hp", c, 745.69987158227022, &["horsepowers"]);
|
||||
linear(reg, "metric horsepower", "PS", c, 735.49875, &["metric horsepowers", "cv"]);
|
||||
linear(reg, "btu per hour", "BTU/h", c, 0.29307107017, &["btu/hr", "btus per hour"]);
|
||||
linear(reg, "foot-pound per second", "ft\u{00B7}lbf/s", c, 1.3558179483, &["foot-pounds per second", "ft-lbf/s"]);
|
||||
linear(reg, "ton of refrigeration", "TR", c, 3516.8528, &["tons of refrigeration"]);
|
||||
linear(reg, "volt-ampere", "VA", c, 1.0, &["volt-amperes"]);
|
||||
linear(reg, "kilovolt-ampere", "kVA", c, 1000.0, &["kilovolt-amperes"]);
|
||||
}
|
||||
|
||||
// ─── FORCE (base: newton) ───────────────────────────────────────────
|
||||
|
||||
fn register_force(reg: &mut UnitRegistry) {
|
||||
let c = UnitCategory::Force;
|
||||
|
||||
linear(reg, "newton", "N", c, 1.0, &["newtons"]);
|
||||
linear(reg, "kilonewton", "kN", c, 1000.0, &["kilonewtons"]);
|
||||
linear(reg, "meganewton", "MN", c, 1e6, &["meganewtons"]);
|
||||
linear(reg, "dyne", "dyn", c, 1e-5, &["dynes"]);
|
||||
linear(reg, "pound-force", "lbf", c, 4.4482216152605, &["pounds-force", "pound force"]);
|
||||
linear(reg, "kilogram-force", "kgf", c, 9.80665, &["kilograms-force", "kilogram force", "kilopond", "kp"]);
|
||||
linear(reg, "gram-force", "gf", c, 0.00980665, &["grams-force", "gram force"]);
|
||||
linear(reg, "ounce-force", "ozf", c, 0.278013851, &["ounces-force", "ounce force"]);
|
||||
linear(reg, "poundal", "pdl", c, 0.138254954376, &["poundals"]);
|
||||
linear(reg, "millinewton", "mN", c, 0.001, &["millinewtons"]);
|
||||
linear(reg, "micronewton", "\u{03BC}N", c, 1e-6, &["micronewtons"]);
|
||||
linear(reg, "sthene", "sn", c, 1000.0, &["sthenes"]);
|
||||
linear(reg, "kip", "kip", c, 4448.2216152605, &["kips", "kilopound-force"]);
|
||||
linear(reg, "ton-force", "tnf", c, 8896.443230521, &["tons-force", "ton force", "short ton-force"]);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_registry() -> UnitRegistry {
|
||||
let mut reg = UnitRegistry {
|
||||
lookup: HashMap::new(),
|
||||
case_sensitive_lookup: HashMap::new(),
|
||||
units: Vec::new(),
|
||||
};
|
||||
register_all(&mut reg);
|
||||
reg
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_categories_have_units() {
|
||||
let reg = make_registry();
|
||||
for cat in UnitCategory::all() {
|
||||
let count = reg.all_units().iter().filter(|u| u.category == *cat).count();
|
||||
assert!(count > 0, "Category {} has no units", cat);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_length_meter_is_base() {
|
||||
let reg = make_registry();
|
||||
let m = reg.lookup("meter").unwrap();
|
||||
assert_eq!(m.to_base(1.0), 1.0);
|
||||
assert_eq!(m.from_base(1.0), 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_length_km_conversion() {
|
||||
let reg = make_registry();
|
||||
let km = reg.lookup("km").unwrap();
|
||||
assert!((km.to_base(1.0) - 1000.0).abs() < 1e-10);
|
||||
assert!((km.from_base(1000.0) - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temperature_conversions() {
|
||||
let reg = make_registry();
|
||||
let c = reg.lookup("C").unwrap();
|
||||
let f = reg.lookup("F").unwrap();
|
||||
|
||||
assert!((c.to_base(0.0) - 273.15).abs() < 1e-10);
|
||||
assert!((f.to_base(32.0) - 273.15).abs() < 1e-10);
|
||||
assert!((f.to_base(212.0) - 373.15).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_registry_builds_without_panic() {
|
||||
let reg = make_registry();
|
||||
assert!(reg.lookup.len() > 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_count_per_category() {
|
||||
let reg = make_registry();
|
||||
let mut total = 0;
|
||||
for cat in UnitCategory::all() {
|
||||
let count = reg.all_units().iter().filter(|u| u.category == *cat).count();
|
||||
total += count;
|
||||
}
|
||||
assert_eq!(total, reg.unit_count());
|
||||
}
|
||||
}
|
||||
327
calcpad-engine/src/units/si_prefix.rs
Normal file
327
calcpad-engine/src/units/si_prefix.rs
Normal file
@@ -0,0 +1,327 @@
|
||||
//! SI prefix handling for unit decomposition.
|
||||
//!
|
||||
//! Supports prefixes from nano (10^-9) through tera (10^12).
|
||||
//! Short-form symbols are case-sensitive: "k" = kilo, "M" = mega, "m" = milli.
|
||||
//! Long-form names are case-insensitive: "kilo", "Kilo", "KILO" all match.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
/// An SI prefix definition.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct SiPrefix {
|
||||
/// Long-form name (e.g., "kilo").
|
||||
pub name: &'static str,
|
||||
/// Short-form symbol (e.g., "k").
|
||||
pub symbol: &'static str,
|
||||
/// Multiplication factor (e.g., 1e3 for kilo).
|
||||
pub factor: f64,
|
||||
}
|
||||
|
||||
/// All supported SI prefixes from nano (10^-9) through tera (10^12).
|
||||
pub static SI_PREFIXES: &[SiPrefix] = &[
|
||||
SiPrefix { name: "tera", symbol: "T", factor: 1e12 },
|
||||
SiPrefix { name: "giga", symbol: "G", factor: 1e9 },
|
||||
SiPrefix { name: "mega", symbol: "M", factor: 1e6 },
|
||||
SiPrefix { name: "kilo", symbol: "k", factor: 1e3 },
|
||||
SiPrefix { name: "centi", symbol: "c", factor: 1e-2 },
|
||||
SiPrefix { name: "milli", symbol: "m", factor: 1e-3 },
|
||||
SiPrefix { name: "micro", symbol: "\u{00B5}", factor: 1e-6 },
|
||||
SiPrefix { name: "nano", symbol: "n", factor: 1e-9 },
|
||||
];
|
||||
|
||||
/// Base unit abbreviations that accept SI prefixes.
|
||||
/// Case-sensitive: "m" (meter), "g" (gram), "s" (second), etc.
|
||||
static SI_COMPATIBLE_ABBREVS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
|
||||
HashSet::from([
|
||||
"m", // Length (meter)
|
||||
"g", // Mass (gram)
|
||||
"L", "l", // Volume (liter)
|
||||
"B", "b", "bit", // Data
|
||||
"s", // Time (second)
|
||||
"Pa", // Pressure (pascal)
|
||||
"J", // Energy (joule)
|
||||
"W", // Power (watt)
|
||||
"N", // Force (newton)
|
||||
"rad", // Angle (radian)
|
||||
"eV", // Energy (electronvolt)
|
||||
"Hz", // Frequency (for future use)
|
||||
])
|
||||
});
|
||||
|
||||
/// Base unit long-form names that accept SI prefixes (lowercase).
|
||||
static SI_COMPATIBLE_NAMES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
|
||||
HashSet::from([
|
||||
"meter", "meters", "metre", "metres",
|
||||
"gram", "grams",
|
||||
"liter", "liters", "litre", "litres",
|
||||
"byte", "bytes", "bit", "bits",
|
||||
"second", "seconds",
|
||||
"pascal", "pascals",
|
||||
"joule", "joules",
|
||||
"watt", "watts",
|
||||
"newton", "newtons",
|
||||
"radian", "radians",
|
||||
"electronvolt", "electronvolts",
|
||||
"hertz",
|
||||
])
|
||||
});
|
||||
|
||||
/// Try to decompose a unit string into an SI prefix symbol + base unit abbreviation.
|
||||
/// Returns `(prefix, remaining_unit_str)` if a valid decomposition is found.
|
||||
///
|
||||
/// Uses case-sensitive matching for symbols:
|
||||
/// - "k" = kilo, "M" = mega, "m" = milli, "G" = giga, etc.
|
||||
pub fn match_short_prefix(input: &str) -> Option<(&'static SiPrefix, &str)> {
|
||||
for prefix in SI_PREFIXES {
|
||||
if let Some(remainder) = input.strip_prefix(prefix.symbol) {
|
||||
if !remainder.is_empty() && SI_COMPATIBLE_ABBREVS.contains(remainder) {
|
||||
return Some((prefix, remainder));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also try Greek small letter mu (U+03BC) as alias for micro sign (U+00B5)
|
||||
if let Some(remainder) = input.strip_prefix('\u{03BC}') {
|
||||
if !remainder.is_empty() && SI_COMPATIBLE_ABBREVS.contains(remainder) {
|
||||
let micro = SI_PREFIXES.iter().find(|p| p.name == "micro").unwrap();
|
||||
return Some((micro, remainder));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Try to decompose a unit string into an SI prefix long-form name + base unit name.
|
||||
/// Returns `(prefix, remaining_unit_str)` if a valid decomposition is found.
|
||||
///
|
||||
/// Case-insensitive matching for long-form names.
|
||||
pub fn match_long_prefix(input: &str) -> Option<(&'static SiPrefix, &str)> {
|
||||
let lower = input.to_lowercase();
|
||||
|
||||
for prefix in SI_PREFIXES {
|
||||
if let Some(remainder) = lower.strip_prefix(prefix.name) {
|
||||
if !remainder.is_empty() && SI_COMPATIBLE_NAMES.contains(remainder) {
|
||||
return Some((prefix, &input[prefix.name.len()..]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_short_prefix_km() {
|
||||
let result = match_short_prefix("km");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "kilo");
|
||||
assert_eq!(prefix.factor, 1e3);
|
||||
assert_eq!(base, "m");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_short_prefix_mw_is_milliwatt() {
|
||||
let result = match_short_prefix("mW");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "milli");
|
||||
assert_eq!(base, "W");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_short_prefix_mega_w() {
|
||||
let result = match_short_prefix("MW");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "mega");
|
||||
assert_eq!(base, "W");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_short_prefix_micro_s_greek_mu() {
|
||||
let result = match_short_prefix("\u{03BC}s");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "micro");
|
||||
assert_eq!(base, "s");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_short_prefix_micro_sign() {
|
||||
let result = match_short_prefix("\u{00B5}s");
|
||||
assert!(result.is_some());
|
||||
let (prefix, _) = result.unwrap();
|
||||
assert_eq!(prefix.name, "micro");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_short_prefix_ns() {
|
||||
let result = match_short_prefix("ns");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "nano");
|
||||
assert_eq!(base, "s");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_short_prefix_gb() {
|
||||
let result = match_short_prefix("GB");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "giga");
|
||||
assert_eq!(base, "B");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_short_prefix_tb() {
|
||||
let result = match_short_prefix("TB");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "tera");
|
||||
assert_eq!(base, "B");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_short_prefix_cm() {
|
||||
let result = match_short_prefix("cm");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "centi");
|
||||
assert_eq!(base, "m");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_short_prefix_mega_pa() {
|
||||
let result = match_short_prefix("MPa");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "mega");
|
||||
assert_eq!(base, "Pa");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_short_prefix_giga_j() {
|
||||
let result = match_short_prefix("GJ");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "giga");
|
||||
assert_eq!(base, "J");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_short_prefix_no_match_standalone() {
|
||||
assert!(match_short_prefix("k").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_short_prefix_no_match_incompatible() {
|
||||
assert!(match_short_prefix("kft").is_none());
|
||||
}
|
||||
|
||||
// ─── Long-form tests ────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_long_prefix_kilometers() {
|
||||
let result = match_long_prefix("kilometers");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "kilo");
|
||||
assert_eq!(base, "meters");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_long_prefix_milligrams() {
|
||||
let result = match_long_prefix("milligrams");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "milli");
|
||||
assert_eq!(base, "grams");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_long_prefix_megawatts() {
|
||||
let result = match_long_prefix("megawatts");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "mega");
|
||||
assert_eq!(base, "watts");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_long_prefix_nanoseconds() {
|
||||
let result = match_long_prefix("nanoseconds");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "nano");
|
||||
assert_eq!(base, "seconds");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_long_prefix_terabytes() {
|
||||
let result = match_long_prefix("terabytes");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "tera");
|
||||
assert_eq!(base, "bytes");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_long_prefix_microseconds() {
|
||||
let result = match_long_prefix("microseconds");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "micro");
|
||||
assert_eq!(base, "seconds");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_long_prefix_centimeters() {
|
||||
let result = match_long_prefix("centimeters");
|
||||
assert!(result.is_some());
|
||||
let (prefix, base) = result.unwrap();
|
||||
assert_eq!(prefix.name, "centi");
|
||||
assert_eq!(base, "meters");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_long_prefix_case_insensitive() {
|
||||
let result = match_long_prefix("Kilometers");
|
||||
assert!(result.is_some());
|
||||
let (prefix, _) = result.unwrap();
|
||||
assert_eq!(prefix.name, "kilo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_long_prefix_no_match_kilofahrenheit() {
|
||||
assert!(match_long_prefix("kilofahrenheit").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_long_prefix_no_match_kilomiles() {
|
||||
assert!(match_long_prefix("kilomiles").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_prefix_factors() {
|
||||
let expected = vec![
|
||||
("tera", 1e12),
|
||||
("giga", 1e9),
|
||||
("mega", 1e6),
|
||||
("kilo", 1e3),
|
||||
("centi", 1e-2),
|
||||
("milli", 1e-3),
|
||||
("micro", 1e-6),
|
||||
("nano", 1e-9),
|
||||
];
|
||||
for (name, factor) in expected {
|
||||
let prefix = SI_PREFIXES.iter().find(|p| p.name == name);
|
||||
assert!(prefix.is_some(), "Missing prefix: {}", name);
|
||||
assert_eq!(prefix.unwrap().factor, factor, "Wrong factor for {}", name);
|
||||
}
|
||||
}
|
||||
}
|
||||
425
calcpad-engine/src/variables/aggregators.rs
Normal file
425
calcpad-engine/src/variables/aggregators.rs
Normal file
@@ -0,0 +1,425 @@
|
||||
//! Section aggregators for CalcPad sheets.
|
||||
//!
|
||||
//! Provides aggregation keywords (`sum`, `total`, `subtotal`, `average`/`avg`,
|
||||
//! `min`, `max`, `count`) that operate over a section of lines, and
|
||||
//! `grand total` which sums all subtotal results.
|
||||
//!
|
||||
//! A **section** is bounded by:
|
||||
//! - Headings (lines starting with `#` followed by a space, e.g. `## Budget`)
|
||||
//! - Other aggregator lines
|
||||
//! - Start of document
|
||||
//!
|
||||
//! Only lines with numeric results are included in aggregation. Comments,
|
||||
//! blank lines, and error lines are skipped.
|
||||
|
||||
use crate::span::Span;
|
||||
use crate::types::{CalcResult, CalcValue};
|
||||
|
||||
/// The kind of aggregation to perform.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AggregatorKind {
|
||||
/// Sum of numeric values in the section. Also used for `total`.
|
||||
Sum,
|
||||
/// Distinct subtotal (tracked separately for grand total).
|
||||
Subtotal,
|
||||
/// Arithmetic mean of numeric values in the section.
|
||||
Average,
|
||||
/// Minimum numeric value in the section.
|
||||
Min,
|
||||
/// Maximum numeric value in the section.
|
||||
Max,
|
||||
/// Count of lines with numeric results in the section.
|
||||
Count,
|
||||
/// Sum of all subtotal values seen so far in the document.
|
||||
GrandTotal,
|
||||
}
|
||||
|
||||
/// Check if a trimmed line is a heading (e.g., `## Monthly Costs`).
|
||||
pub fn is_heading(line: &str) -> bool {
|
||||
let trimmed = line.trim();
|
||||
// Match lines starting with one or more '#' followed by a space
|
||||
let bytes = trimmed.as_bytes();
|
||||
if bytes.is_empty() || bytes[0] != b'#' {
|
||||
return false;
|
||||
}
|
||||
let mut i = 0;
|
||||
while i < bytes.len() && bytes[i] == b'#' {
|
||||
i += 1;
|
||||
}
|
||||
// Must have at least one # and be followed by a space (or be only #s)
|
||||
i > 0 && i <= 6 && (i >= bytes.len() || bytes[i] == b' ')
|
||||
}
|
||||
|
||||
/// Detect if a trimmed line is a standalone aggregator keyword.
|
||||
/// Returns the aggregator kind, or None if the line is not an aggregator.
|
||||
pub fn detect_aggregator(line: &str) -> Option<AggregatorKind> {
|
||||
let trimmed = line.trim().to_lowercase();
|
||||
|
||||
// Check for two-word "grand total" first
|
||||
if trimmed == "grand total" {
|
||||
return Some(AggregatorKind::GrandTotal);
|
||||
}
|
||||
|
||||
match trimmed.as_str() {
|
||||
"sum" => Some(AggregatorKind::Sum),
|
||||
"total" => Some(AggregatorKind::Sum),
|
||||
"subtotal" => Some(AggregatorKind::Subtotal),
|
||||
"average" | "avg" => Some(AggregatorKind::Average),
|
||||
"min" => Some(AggregatorKind::Min),
|
||||
"max" => Some(AggregatorKind::Max),
|
||||
"count" => Some(AggregatorKind::Count),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a line is a section boundary (heading or aggregator).
|
||||
pub fn is_section_boundary(line: &str) -> bool {
|
||||
is_heading(line) || detect_aggregator(line).is_some()
|
||||
}
|
||||
|
||||
/// Collect numeric values from the section above the given line index.
|
||||
///
|
||||
/// Walks backwards from the line before `line_index` (0-indexed) until hitting
|
||||
/// a section boundary (heading, another aggregator, or start of document).
|
||||
///
|
||||
/// Only non-error results with extractable numeric values are included.
|
||||
pub fn collect_section_values(
|
||||
results: &[CalcResult],
|
||||
sources: &[String],
|
||||
line_index: usize,
|
||||
) -> Vec<f64> {
|
||||
let mut values = Vec::new();
|
||||
|
||||
if line_index == 0 {
|
||||
return values;
|
||||
}
|
||||
|
||||
for i in (0..line_index).rev() {
|
||||
let source = &sources[i];
|
||||
|
||||
// Stop at headings
|
||||
if is_heading(source) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Stop at other aggregator lines
|
||||
if detect_aggregator(source).is_some() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Extract numeric value from result
|
||||
if let Some(val) = extract_numeric_value(&results[i]) {
|
||||
values.push(val);
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse to document order (we collected bottom-up)
|
||||
values.reverse();
|
||||
values
|
||||
}
|
||||
|
||||
/// Extract a numeric value from a CalcResult, if it has one.
|
||||
fn extract_numeric_value(result: &CalcResult) -> Option<f64> {
|
||||
match &result.value {
|
||||
CalcValue::Number { value } => Some(*value),
|
||||
CalcValue::UnitValue { value, .. } => Some(*value),
|
||||
CalcValue::CurrencyValue { amount, .. } => Some(*amount),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute an aggregation over the given values.
|
||||
pub fn compute_aggregation(kind: AggregatorKind, values: &[f64], span: Span) -> CalcResult {
|
||||
match kind {
|
||||
AggregatorKind::GrandTotal => {
|
||||
// Grand total is handled separately (sums subtotals, not section values)
|
||||
let sum: f64 = values.iter().sum();
|
||||
CalcResult::number(sum, span)
|
||||
}
|
||||
_ => {
|
||||
if values.is_empty() {
|
||||
return CalcResult::number(0.0, span);
|
||||
}
|
||||
match kind {
|
||||
AggregatorKind::Sum | AggregatorKind::Subtotal => {
|
||||
let sum: f64 = values.iter().sum();
|
||||
CalcResult::number(sum, span)
|
||||
}
|
||||
AggregatorKind::Average => {
|
||||
let sum: f64 = values.iter().sum();
|
||||
CalcResult::number(sum / values.len() as f64, span)
|
||||
}
|
||||
AggregatorKind::Min => {
|
||||
let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
|
||||
CalcResult::number(min, span)
|
||||
}
|
||||
AggregatorKind::Max => {
|
||||
let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
|
||||
CalcResult::number(max, span)
|
||||
}
|
||||
AggregatorKind::Count => {
|
||||
CalcResult::number(values.len() as f64, span)
|
||||
}
|
||||
AggregatorKind::GrandTotal => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute a grand total from a list of subtotal values.
|
||||
pub fn compute_grand_total(subtotal_values: &[f64], span: Span) -> CalcResult {
|
||||
let sum: f64 = subtotal_values.iter().sum();
|
||||
CalcResult::number(sum, span)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// --- is_heading ---
|
||||
|
||||
#[test]
|
||||
fn test_heading_h1() {
|
||||
assert!(is_heading("# Title"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_heading_h2() {
|
||||
assert!(is_heading("## Subtitle"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_heading_h3_with_whitespace() {
|
||||
assert!(is_heading(" ### Indented Heading "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_heading_hash_in_middle() {
|
||||
assert!(!is_heading("this is #not a heading"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_heading_hash_ref() {
|
||||
assert!(!is_heading("#1 * 2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_heading_empty() {
|
||||
assert!(!is_heading(""));
|
||||
}
|
||||
|
||||
// --- detect_aggregator ---
|
||||
|
||||
#[test]
|
||||
fn test_detect_sum() {
|
||||
assert_eq!(detect_aggregator("sum"), Some(AggregatorKind::Sum));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_total() {
|
||||
assert_eq!(detect_aggregator("total"), Some(AggregatorKind::Sum));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_subtotal() {
|
||||
assert_eq!(detect_aggregator("subtotal"), Some(AggregatorKind::Subtotal));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_average() {
|
||||
assert_eq!(detect_aggregator("average"), Some(AggregatorKind::Average));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_avg() {
|
||||
assert_eq!(detect_aggregator("avg"), Some(AggregatorKind::Average));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_min() {
|
||||
assert_eq!(detect_aggregator("min"), Some(AggregatorKind::Min));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_max() {
|
||||
assert_eq!(detect_aggregator("max"), Some(AggregatorKind::Max));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_count() {
|
||||
assert_eq!(detect_aggregator("count"), Some(AggregatorKind::Count));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_grand_total() {
|
||||
assert_eq!(detect_aggregator("grand total"), Some(AggregatorKind::GrandTotal));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_case_insensitive() {
|
||||
assert_eq!(detect_aggregator(" SUM "), Some(AggregatorKind::Sum));
|
||||
assert_eq!(detect_aggregator(" Grand Total "), Some(AggregatorKind::GrandTotal));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_not_aggregator() {
|
||||
assert_eq!(detect_aggregator("sum + 5"), None);
|
||||
assert_eq!(detect_aggregator("total expense"), None);
|
||||
assert_eq!(detect_aggregator("x = 5"), None);
|
||||
}
|
||||
|
||||
// --- collect_section_values ---
|
||||
|
||||
#[test]
|
||||
fn test_collect_section_basic() {
|
||||
let results = vec![
|
||||
CalcResult::number(10.0, Span::new(0, 2)),
|
||||
CalcResult::number(20.0, Span::new(0, 2)),
|
||||
CalcResult::number(30.0, Span::new(0, 2)),
|
||||
];
|
||||
let sources = vec!["10".to_string(), "20".to_string(), "30".to_string()];
|
||||
let values = collect_section_values(&results, &sources, 3);
|
||||
assert_eq!(values, vec![10.0, 20.0, 30.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collect_section_stops_at_heading() {
|
||||
let results = vec![
|
||||
CalcResult::number(10.0, Span::new(0, 2)),
|
||||
CalcResult::error("heading", Span::new(0, 8)),
|
||||
CalcResult::number(20.0, Span::new(0, 2)),
|
||||
CalcResult::number(30.0, Span::new(0, 2)),
|
||||
];
|
||||
let sources = vec![
|
||||
"10".to_string(),
|
||||
"## Section".to_string(),
|
||||
"20".to_string(),
|
||||
"30".to_string(),
|
||||
];
|
||||
let values = collect_section_values(&results, &sources, 4);
|
||||
assert_eq!(values, vec![20.0, 30.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collect_section_stops_at_aggregator() {
|
||||
let results = vec![
|
||||
CalcResult::number(10.0, Span::new(0, 2)),
|
||||
CalcResult::number(20.0, Span::new(0, 2)),
|
||||
CalcResult::number(30.0, Span::new(0, 3)), // sum result
|
||||
CalcResult::number(40.0, Span::new(0, 2)),
|
||||
CalcResult::number(50.0, Span::new(0, 2)),
|
||||
];
|
||||
let sources = vec![
|
||||
"10".to_string(),
|
||||
"20".to_string(),
|
||||
"sum".to_string(),
|
||||
"40".to_string(),
|
||||
"50".to_string(),
|
||||
];
|
||||
let values = collect_section_values(&results, &sources, 5);
|
||||
assert_eq!(values, vec![40.0, 50.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collect_section_skips_errors() {
|
||||
let results = vec![
|
||||
CalcResult::number(10.0, Span::new(0, 2)),
|
||||
CalcResult::error("parse error", Span::new(0, 3)),
|
||||
CalcResult::number(30.0, Span::new(0, 2)),
|
||||
];
|
||||
let sources = vec!["10".to_string(), "???".to_string(), "30".to_string()];
|
||||
let values = collect_section_values(&results, &sources, 3);
|
||||
assert_eq!(values, vec![10.0, 30.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collect_section_empty() {
|
||||
let results = vec![
|
||||
CalcResult::error("heading", Span::new(0, 8)),
|
||||
];
|
||||
let sources = vec!["## Section".to_string()];
|
||||
let values = collect_section_values(&results, &sources, 1);
|
||||
assert!(values.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collect_at_start() {
|
||||
let values = collect_section_values(&[], &[], 0);
|
||||
assert!(values.is_empty());
|
||||
}
|
||||
|
||||
// --- compute_aggregation ---
|
||||
|
||||
#[test]
|
||||
fn test_sum_aggregation() {
|
||||
let values = vec![10.0, 20.0, 30.0, 40.0];
|
||||
let result = compute_aggregation(AggregatorKind::Sum, &values, Span::new(0, 3));
|
||||
assert_eq!(result.value, CalcValue::Number { value: 100.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subtotal_aggregation() {
|
||||
let values = vec![10.0, 20.0, 30.0];
|
||||
let result = compute_aggregation(AggregatorKind::Subtotal, &values, Span::new(0, 8));
|
||||
assert_eq!(result.value, CalcValue::Number { value: 60.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_average_aggregation() {
|
||||
let values = vec![10.0, 20.0, 30.0];
|
||||
let result = compute_aggregation(AggregatorKind::Average, &values, Span::new(0, 7));
|
||||
assert_eq!(result.value, CalcValue::Number { value: 20.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_min_aggregation() {
|
||||
let values = vec![5.0, 12.0, 3.0, 8.0];
|
||||
let result = compute_aggregation(AggregatorKind::Min, &values, Span::new(0, 3));
|
||||
assert_eq!(result.value, CalcValue::Number { value: 3.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_aggregation() {
|
||||
let values = vec![5.0, 12.0, 3.0, 8.0];
|
||||
let result = compute_aggregation(AggregatorKind::Max, &values, Span::new(0, 3));
|
||||
assert_eq!(result.value, CalcValue::Number { value: 12.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_count_aggregation() {
|
||||
let values = vec![5.0, 12.0, 3.0, 8.0];
|
||||
let result = compute_aggregation(AggregatorKind::Count, &values, Span::new(0, 5));
|
||||
assert_eq!(result.value, CalcValue::Number { value: 4.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_section_returns_zero() {
|
||||
let result = compute_aggregation(AggregatorKind::Sum, &[], Span::new(0, 3));
|
||||
assert_eq!(result.value, CalcValue::Number { value: 0.0 });
|
||||
|
||||
let result = compute_aggregation(AggregatorKind::Average, &[], Span::new(0, 7));
|
||||
assert_eq!(result.value, CalcValue::Number { value: 0.0 });
|
||||
}
|
||||
|
||||
// --- compute_grand_total ---
|
||||
|
||||
#[test]
|
||||
fn test_grand_total_two_sections() {
|
||||
let subtotals = vec![300.0, 125.0];
|
||||
let result = compute_grand_total(&subtotals, Span::new(0, 11));
|
||||
assert_eq!(result.value, CalcValue::Number { value: 425.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_grand_total_empty() {
|
||||
let result = compute_grand_total(&[], Span::new(0, 11));
|
||||
assert_eq!(result.value, CalcValue::Number { value: 0.0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_grand_total_includes_zero_subtotals() {
|
||||
let subtotals = vec![300.0, 0.0, 125.0];
|
||||
let result = compute_grand_total(&subtotals, Span::new(0, 11));
|
||||
assert_eq!(result.value, CalcValue::Number { value: 425.0 });
|
||||
}
|
||||
}
|
||||
552
calcpad-engine/src/variables/autocomplete.rs
Normal file
552
calcpad-engine/src/variables/autocomplete.rs
Normal file
@@ -0,0 +1,552 @@
|
||||
//! Autocomplete provider for CalcPad.
|
||||
//!
|
||||
//! Provides completion suggestions for variables, functions, keywords, and units
|
||||
//! based on the current cursor position and sheet content.
|
||||
//!
|
||||
//! This module is purely text-based — it does not depend on the evaluation engine.
|
||||
//! It scans the sheet content for variable declarations and matches against
|
||||
//! built-in registries of functions, keywords, and units.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// The kind of completion item.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum CompletionKind {
|
||||
/// A user-declared variable.
|
||||
Variable,
|
||||
/// A built-in math function.
|
||||
Function,
|
||||
/// An aggregator keyword (sum, total, etc.).
|
||||
Keyword,
|
||||
/// A unit suffix (kg, km, etc.).
|
||||
Unit,
|
||||
}
|
||||
|
||||
/// A single autocomplete suggestion.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct CompletionItem {
|
||||
/// Display label for the suggestion.
|
||||
pub label: String,
|
||||
/// Text to insert when the suggestion is accepted.
|
||||
pub insert_text: String,
|
||||
/// Category of the completion.
|
||||
pub kind: CompletionKind,
|
||||
/// Optional description/detail.
|
||||
pub detail: Option<String>,
|
||||
}
|
||||
|
||||
/// Context for computing autocomplete suggestions.
|
||||
pub struct CompletionContext<'a> {
|
||||
/// Current line text.
|
||||
pub line: &'a str,
|
||||
/// Cursor position within the line (0-indexed byte offset).
|
||||
pub cursor: usize,
|
||||
/// Full sheet content (all lines joined by newlines).
|
||||
pub sheet_content: &'a str,
|
||||
/// Current line number (1-indexed).
|
||||
pub line_number: usize,
|
||||
}
|
||||
|
||||
/// Result of an autocomplete query.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CompletionResult {
|
||||
/// Matching completion items.
|
||||
pub items: Vec<CompletionItem>,
|
||||
/// The prefix being matched.
|
||||
pub prefix: String,
|
||||
/// Start position for text replacement in the line (byte offset).
|
||||
pub replace_start: usize,
|
||||
/// End position for text replacement in the line (byte offset).
|
||||
pub replace_end: usize,
|
||||
}
|
||||
|
||||
/// Info about the extracted prefix at the cursor.
|
||||
struct PrefixInfo {
|
||||
prefix: String,
|
||||
start: usize,
|
||||
end: usize,
|
||||
is_unit_context: bool,
|
||||
}
|
||||
|
||||
// --- Built-in registries ---
|
||||
|
||||
fn keyword_completions() -> Vec<CompletionItem> {
|
||||
vec![
|
||||
CompletionItem {
|
||||
label: "sum".to_string(),
|
||||
insert_text: "sum".to_string(),
|
||||
kind: CompletionKind::Keyword,
|
||||
detail: Some("Sum of section values".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "total".to_string(),
|
||||
insert_text: "total".to_string(),
|
||||
kind: CompletionKind::Keyword,
|
||||
detail: Some("Total of section values".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "subtotal".to_string(),
|
||||
insert_text: "subtotal".to_string(),
|
||||
kind: CompletionKind::Keyword,
|
||||
detail: Some("Subtotal of section values".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "average".to_string(),
|
||||
insert_text: "average".to_string(),
|
||||
kind: CompletionKind::Keyword,
|
||||
detail: Some("Average of section values".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "count".to_string(),
|
||||
insert_text: "count".to_string(),
|
||||
kind: CompletionKind::Keyword,
|
||||
detail: Some("Count of section values".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "prev".to_string(),
|
||||
insert_text: "prev".to_string(),
|
||||
kind: CompletionKind::Keyword,
|
||||
detail: Some("Previous line result".to_string()),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
fn function_completions() -> Vec<CompletionItem> {
|
||||
vec![
|
||||
CompletionItem {
|
||||
label: "sqrt".to_string(),
|
||||
insert_text: "sqrt(".to_string(),
|
||||
kind: CompletionKind::Function,
|
||||
detail: Some("Square root".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "abs".to_string(),
|
||||
insert_text: "abs(".to_string(),
|
||||
kind: CompletionKind::Function,
|
||||
detail: Some("Absolute value".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "round".to_string(),
|
||||
insert_text: "round(".to_string(),
|
||||
kind: CompletionKind::Function,
|
||||
detail: Some("Round to nearest integer".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "floor".to_string(),
|
||||
insert_text: "floor(".to_string(),
|
||||
kind: CompletionKind::Function,
|
||||
detail: Some("Round down".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "ceil".to_string(),
|
||||
insert_text: "ceil(".to_string(),
|
||||
kind: CompletionKind::Function,
|
||||
detail: Some("Round up".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "log".to_string(),
|
||||
insert_text: "log(".to_string(),
|
||||
kind: CompletionKind::Function,
|
||||
detail: Some("Base-10 logarithm".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "ln".to_string(),
|
||||
insert_text: "ln(".to_string(),
|
||||
kind: CompletionKind::Function,
|
||||
detail: Some("Natural logarithm".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "sin".to_string(),
|
||||
insert_text: "sin(".to_string(),
|
||||
kind: CompletionKind::Function,
|
||||
detail: Some("Sine".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "cos".to_string(),
|
||||
insert_text: "cos(".to_string(),
|
||||
kind: CompletionKind::Function,
|
||||
detail: Some("Cosine".to_string()),
|
||||
},
|
||||
CompletionItem {
|
||||
label: "tan".to_string(),
|
||||
insert_text: "tan(".to_string(),
|
||||
kind: CompletionKind::Function,
|
||||
detail: Some("Tangent".to_string()),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
fn unit_completions() -> Vec<CompletionItem> {
|
||||
vec![
|
||||
// Mass
|
||||
CompletionItem { label: "kg".to_string(), insert_text: "kg".to_string(), kind: CompletionKind::Unit, detail: Some("Kilograms".to_string()) },
|
||||
CompletionItem { label: "lb".to_string(), insert_text: "lb".to_string(), kind: CompletionKind::Unit, detail: Some("Pounds".to_string()) },
|
||||
CompletionItem { label: "oz".to_string(), insert_text: "oz".to_string(), kind: CompletionKind::Unit, detail: Some("Ounces".to_string()) },
|
||||
CompletionItem { label: "mg".to_string(), insert_text: "mg".to_string(), kind: CompletionKind::Unit, detail: Some("Milligrams".to_string()) },
|
||||
// Length
|
||||
CompletionItem { label: "km".to_string(), insert_text: "km".to_string(), kind: CompletionKind::Unit, detail: Some("Kilometers".to_string()) },
|
||||
CompletionItem { label: "mm".to_string(), insert_text: "mm".to_string(), kind: CompletionKind::Unit, detail: Some("Millimeters".to_string()) },
|
||||
CompletionItem { label: "cm".to_string(), insert_text: "cm".to_string(), kind: CompletionKind::Unit, detail: Some("Centimeters".to_string()) },
|
||||
CompletionItem { label: "ft".to_string(), insert_text: "ft".to_string(), kind: CompletionKind::Unit, detail: Some("Feet".to_string()) },
|
||||
CompletionItem { label: "in".to_string(), insert_text: "in".to_string(), kind: CompletionKind::Unit, detail: Some("Inches".to_string()) },
|
||||
// Volume
|
||||
CompletionItem { label: "ml".to_string(), insert_text: "ml".to_string(), kind: CompletionKind::Unit, detail: Some("Milliliters".to_string()) },
|
||||
// Data
|
||||
CompletionItem { label: "kB".to_string(), insert_text: "kB".to_string(), kind: CompletionKind::Unit, detail: Some("Kilobytes".to_string()) },
|
||||
CompletionItem { label: "MB".to_string(), insert_text: "MB".to_string(), kind: CompletionKind::Unit, detail: Some("Megabytes".to_string()) },
|
||||
CompletionItem { label: "GB".to_string(), insert_text: "GB".to_string(), kind: CompletionKind::Unit, detail: Some("Gigabytes".to_string()) },
|
||||
CompletionItem { label: "TB".to_string(), insert_text: "TB".to_string(), kind: CompletionKind::Unit, detail: Some("Terabytes".to_string()) },
|
||||
// Time
|
||||
CompletionItem { label: "ms".to_string(), insert_text: "ms".to_string(), kind: CompletionKind::Unit, detail: Some("Milliseconds".to_string()) },
|
||||
CompletionItem { label: "hr".to_string(), insert_text: "hr".to_string(), kind: CompletionKind::Unit, detail: Some("Hours".to_string()) },
|
||||
]
|
||||
}
|
||||
|
||||
// --- Prefix extraction ---
|
||||
|
||||
/// Extract the identifier prefix at the cursor position.
|
||||
///
|
||||
/// Handles unit context detection: when letters immediately follow digits
|
||||
/// (e.g., "50km"), the prefix is the letter portion ("km") and is_unit_context is true.
|
||||
fn extract_prefix(line: &str, cursor: usize) -> Option<PrefixInfo> {
|
||||
if cursor == 0 || cursor > line.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let bytes = line.as_bytes();
|
||||
|
||||
// Walk backwards collecting word characters (alphanumeric + underscore)
|
||||
let mut start = cursor;
|
||||
while start > 0 && is_word_char(bytes[start - 1]) {
|
||||
start -= 1;
|
||||
}
|
||||
|
||||
if start == cursor {
|
||||
return None;
|
||||
}
|
||||
|
||||
let full_word = &line[start..cursor];
|
||||
|
||||
// If the word starts with a letter or underscore, it's a normal identifier prefix
|
||||
if full_word.as_bytes()[0].is_ascii_alphabetic() || full_word.as_bytes()[0] == b'_' {
|
||||
return Some(PrefixInfo {
|
||||
prefix: full_word.to_string(),
|
||||
start,
|
||||
end: cursor,
|
||||
is_unit_context: false,
|
||||
});
|
||||
}
|
||||
|
||||
// Word starts with digits — find where letters begin for unit context
|
||||
let letter_start = full_word
|
||||
.bytes()
|
||||
.position(|b| b.is_ascii_alphabetic() || b == b'_');
|
||||
|
||||
match letter_start {
|
||||
Some(offset) => {
|
||||
let prefix = full_word[offset..].to_string();
|
||||
let abs_start = start + offset;
|
||||
Some(PrefixInfo {
|
||||
prefix,
|
||||
start: abs_start,
|
||||
end: cursor,
|
||||
is_unit_context: true,
|
||||
})
|
||||
}
|
||||
None => None, // All digits, no completion
|
||||
}
|
||||
}
|
||||
|
||||
fn is_word_char(b: u8) -> bool {
|
||||
b.is_ascii_alphanumeric() || b == b'_'
|
||||
}
|
||||
|
||||
// --- Variable extraction ---
|
||||
|
||||
/// Extract declared variable names from the sheet content.
|
||||
/// Scans each line for `identifier = expression` patterns.
|
||||
/// Excludes the current line to avoid self-reference.
|
||||
fn extract_variables(sheet_content: &str, current_line_number: usize) -> Vec<CompletionItem> {
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
let mut items = Vec::new();
|
||||
|
||||
for (i, line) in sheet_content.lines().enumerate() {
|
||||
let line_num = i + 1; // 1-indexed
|
||||
if line_num == current_line_number {
|
||||
continue;
|
||||
}
|
||||
|
||||
let trimmed = line.trim();
|
||||
if let Some(name) = extract_variable_name(trimmed) {
|
||||
if !seen.contains(&name) {
|
||||
seen.insert(name.clone());
|
||||
items.push(CompletionItem {
|
||||
label: name.clone(),
|
||||
insert_text: name,
|
||||
kind: CompletionKind::Variable,
|
||||
detail: Some(format!("Variable (line {})", line_num)),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
items
|
||||
}
|
||||
|
||||
/// Extract the variable name from an assignment line.
|
||||
/// Returns Some(name) if the line matches `identifier = ...`.
|
||||
fn extract_variable_name(line: &str) -> Option<String> {
|
||||
let bytes = line.as_bytes();
|
||||
if bytes.is_empty() || (!bytes[0].is_ascii_alphabetic() && bytes[0] != b'_') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut i = 0;
|
||||
while i < bytes.len() && is_word_char(bytes[i]) {
|
||||
i += 1;
|
||||
}
|
||||
let name = &line[..i];
|
||||
|
||||
// Skip whitespace
|
||||
while i < bytes.len() && bytes[i].is_ascii_whitespace() {
|
||||
i += 1;
|
||||
}
|
||||
|
||||
// Must be followed by '=' but not '=='
|
||||
if i < bytes.len() && bytes[i] == b'=' {
|
||||
if i + 1 < bytes.len() && bytes[i + 1] == b'=' {
|
||||
return None; // comparison, not assignment
|
||||
}
|
||||
return Some(name.to_string());
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
// --- Core completion function ---
|
||||
|
||||
/// Get autocomplete suggestions for the current cursor position.
|
||||
///
|
||||
/// Returns `None` if:
|
||||
/// - The prefix is less than 2 characters
|
||||
/// - No completions match the prefix
|
||||
pub fn get_completions(context: &CompletionContext) -> Option<CompletionResult> {
|
||||
let prefix_info = extract_prefix(context.line, context.cursor)?;
|
||||
|
||||
// Enforce 2+ character minimum threshold
|
||||
if prefix_info.prefix.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let lower_prefix = prefix_info.prefix.to_lowercase();
|
||||
|
||||
let candidates = if prefix_info.is_unit_context {
|
||||
// Unit context: only suggest units
|
||||
unit_completions()
|
||||
} else {
|
||||
// General context: suggest variables, functions, and keywords
|
||||
let variables = extract_variables(context.sheet_content, context.line_number);
|
||||
let mut all = variables;
|
||||
all.extend(function_completions());
|
||||
all.extend(keyword_completions());
|
||||
all
|
||||
};
|
||||
|
||||
// Filter by case-insensitive prefix match
|
||||
let mut filtered: Vec<CompletionItem> = candidates
|
||||
.into_iter()
|
||||
.filter(|item| item.label.to_lowercase().starts_with(&lower_prefix))
|
||||
.collect();
|
||||
|
||||
if filtered.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Sort: exact case match first, then alphabetical
|
||||
let prefix_clone = prefix_info.prefix.clone();
|
||||
filtered.sort_by(|a, b| {
|
||||
let a_exact = if a.label.starts_with(&prefix_clone) { 0 } else { 1 };
|
||||
let b_exact = if b.label.starts_with(&prefix_clone) { 0 } else { 1 };
|
||||
if a_exact != b_exact {
|
||||
return a_exact.cmp(&b_exact);
|
||||
}
|
||||
a.label.cmp(&b.label)
|
||||
});
|
||||
|
||||
Some(CompletionResult {
|
||||
items: filtered,
|
||||
prefix: prefix_info.prefix,
|
||||
replace_start: prefix_info.start,
|
||||
replace_end: prefix_info.end,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_context<'a>(
|
||||
line: &'a str,
|
||||
cursor: usize,
|
||||
sheet: &'a str,
|
||||
line_number: usize,
|
||||
) -> CompletionContext<'a> {
|
||||
CompletionContext {
|
||||
line,
|
||||
cursor,
|
||||
sheet_content: sheet,
|
||||
line_number,
|
||||
}
|
||||
}
|
||||
|
||||
// --- AC1: Variable suggestions for 2+ character prefix ---
|
||||
|
||||
#[test]
|
||||
fn test_variable_suggestions() {
|
||||
let sheet = "monthly_rent = 1250\nmonthly_insurance = 200\nmortgage_payment = 800\n";
|
||||
let ctx = make_context("mo", 2, sheet, 4);
|
||||
let result = get_completions(&ctx).unwrap();
|
||||
assert_eq!(result.prefix, "mo");
|
||||
assert!(result.items.len() >= 2);
|
||||
let labels: Vec<&str> = result.items.iter().map(|i| i.label.as_str()).collect();
|
||||
assert!(labels.contains(&"monthly_rent"));
|
||||
assert!(labels.contains(&"monthly_insurance"));
|
||||
assert!(labels.contains(&"mortgage_payment"));
|
||||
}
|
||||
|
||||
// --- AC4: Built-in function suggestions ---
|
||||
|
||||
#[test]
|
||||
fn test_function_suggestions_sq() {
|
||||
let ctx = make_context("sq", 2, "", 1);
|
||||
let result = get_completions(&ctx).unwrap();
|
||||
let labels: Vec<&str> = result.items.iter().map(|i| i.label.as_str()).collect();
|
||||
assert!(labels.contains(&"sqrt"));
|
||||
}
|
||||
|
||||
// --- AC5: No suggestions for single character ---
|
||||
|
||||
#[test]
|
||||
fn test_no_suggestions_single_char() {
|
||||
let ctx = make_context("m", 1, "", 1);
|
||||
let result = get_completions(&ctx);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
// --- AC6: No suggestions when nothing matches ---
|
||||
|
||||
#[test]
|
||||
fn test_no_suggestions_no_match() {
|
||||
let ctx = make_context("zzzz", 4, "", 1);
|
||||
let result = get_completions(&ctx);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
// --- AC7: Unit context after number ---
|
||||
|
||||
#[test]
|
||||
fn test_unit_context_after_number() {
|
||||
let ctx = make_context("50km", 4, "", 1);
|
||||
let result = get_completions(&ctx).unwrap();
|
||||
assert!(result.items.iter().all(|i| i.kind == CompletionKind::Unit));
|
||||
let labels: Vec<&str> = result.items.iter().map(|i| i.label.as_str()).collect();
|
||||
assert!(labels.contains(&"km"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unit_context_kg() {
|
||||
let ctx = make_context("50kg", 4, "", 1);
|
||||
let result = get_completions(&ctx).unwrap();
|
||||
let labels: Vec<&str> = result.items.iter().map(|i| i.label.as_str()).collect();
|
||||
assert!(labels.contains(&"kg"));
|
||||
}
|
||||
|
||||
// --- Edge cases ---
|
||||
|
||||
#[test]
|
||||
fn test_empty_line() {
|
||||
let ctx = make_context("", 0, "", 1);
|
||||
let result = get_completions(&ctx);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cursor_at_start() {
|
||||
let ctx = make_context("sum", 0, "", 1);
|
||||
let result = get_completions(&ctx);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_excludes_current_line_variables() {
|
||||
let sheet = "my_var = 10\nmy_other = 20";
|
||||
// Cursor is on line 1 typing "my" — should not suggest my_var from line 1
|
||||
let ctx = make_context("my", 2, sheet, 1);
|
||||
let result = get_completions(&ctx).unwrap();
|
||||
let labels: Vec<&str> = result.items.iter().map(|i| i.label.as_str()).collect();
|
||||
assert!(!labels.contains(&"my_var")); // excluded: same line
|
||||
assert!(labels.contains(&"my_other")); // included: different line
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keyword_suggestions() {
|
||||
let ctx = make_context("su", 2, "", 1);
|
||||
let result = get_completions(&ctx).unwrap();
|
||||
let labels: Vec<&str> = result.items.iter().map(|i| i.label.as_str()).collect();
|
||||
assert!(labels.contains(&"sum"));
|
||||
assert!(labels.contains(&"subtotal"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replace_range() {
|
||||
let ctx = make_context("x + sq", 6, "", 1);
|
||||
let result = get_completions(&ctx).unwrap();
|
||||
assert_eq!(result.prefix, "sq");
|
||||
assert_eq!(result.replace_start, 4);
|
||||
assert_eq!(result.replace_end, 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prev_suggestion() {
|
||||
let ctx = make_context("pr", 2, "", 1);
|
||||
let result = get_completions(&ctx).unwrap();
|
||||
let labels: Vec<&str> = result.items.iter().map(|i| i.label.as_str()).collect();
|
||||
assert!(labels.contains(&"prev"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_case_insensitive_matching() {
|
||||
let ctx = make_context("SU", 2, "", 1);
|
||||
let result = get_completions(&ctx).unwrap();
|
||||
let labels: Vec<&str> = result.items.iter().map(|i| i.label.as_str()).collect();
|
||||
assert!(labels.contains(&"sum"));
|
||||
assert!(labels.contains(&"subtotal"));
|
||||
}
|
||||
|
||||
// --- Variable extraction ---
|
||||
|
||||
#[test]
|
||||
fn test_extract_variable_name_valid() {
|
||||
assert_eq!(extract_variable_name("x = 5"), Some("x".to_string()));
|
||||
assert_eq!(
|
||||
extract_variable_name("tax_rate = 0.15"),
|
||||
Some("tax_rate".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
extract_variable_name("_temp = 100"),
|
||||
Some("_temp".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
extract_variable_name("item1 = 42"),
|
||||
Some("item1".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_variable_name_invalid() {
|
||||
assert_eq!(extract_variable_name("5 + 3"), None);
|
||||
assert_eq!(extract_variable_name("== 5"), None);
|
||||
assert_eq!(extract_variable_name(""), None);
|
||||
assert_eq!(extract_variable_name("x == 5"), None);
|
||||
}
|
||||
}
|
||||
39
calcpad-engine/src/variables/mod.rs
Normal file
39
calcpad-engine/src/variables/mod.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
//! Variables, line references, aggregators, and autocomplete for CalcPad.
|
||||
//!
|
||||
//! This module provides the features from Epic 5 (Variables, Line References &
|
||||
//! Aggregators) that extend the CalcPad engine beyond simple per-line evaluation:
|
||||
//!
|
||||
//! - **Line references** (`line1`, `#1`): Reference the result of a specific line
|
||||
//! by number, with renumbering support when lines are inserted/deleted and
|
||||
//! circular reference detection.
|
||||
//!
|
||||
//! - **Aggregators** (`sum`, `total`, `subtotal`, `average`/`avg`, `min`, `max`,
|
||||
//! `count`, `grand total`): Compute over a section of lines bounded by headings
|
||||
//! or other aggregator lines.
|
||||
//!
|
||||
//! - **Autocomplete**: Provides completion suggestions for variables, functions,
|
||||
//! keywords, and units based on prefix matching (2+ characters).
|
||||
//!
|
||||
//! Note: Variable declaration/usage (`x = 5`, then `x * 2`) and previous-line
|
||||
//! references (`prev`, `ans`) are handled by the core engine modules:
|
||||
//! - `context.rs` / `EvalContext` — stores variables and resolves `__prev`
|
||||
//! - `sheet_context.rs` / `SheetContext` — manages multi-line evaluation with
|
||||
//! dependency tracking, storing line results as `__line_N` variables
|
||||
//! - `interpreter.rs` — evaluates `LineRef`, `PrevRef`, and `FunctionCall` AST nodes
|
||||
//! - `lexer.rs` / `parser.rs` — tokenize and parse `lineN`, `#N`, `prev`, `ans`
|
||||
|
||||
pub mod aggregators;
|
||||
pub mod autocomplete;
|
||||
pub mod references;
|
||||
|
||||
// Re-export key types for convenience.
|
||||
pub use aggregators::{
|
||||
AggregatorKind, collect_section_values, compute_aggregation, compute_grand_total,
|
||||
detect_aggregator, is_heading, is_section_boundary,
|
||||
};
|
||||
pub use autocomplete::{
|
||||
get_completions, CompletionContext, CompletionItem, CompletionKind, CompletionResult,
|
||||
};
|
||||
pub use references::{
|
||||
detect_circular_line_refs, extract_line_refs, renumber_after_delete, renumber_after_insert,
|
||||
};
|
||||
365
calcpad-engine/src/variables/references.rs
Normal file
365
calcpad-engine/src/variables/references.rs
Normal file
@@ -0,0 +1,365 @@
|
||||
//! Line reference support for CalcPad.
|
||||
//!
|
||||
//! Provides line references (`line1`, `#1`) that resolve to the result of a
|
||||
//! specific line by number, and renumbering logic for when lines are inserted
|
||||
//! or deleted.
|
||||
//!
|
||||
//! Line references are 1-indexed (matching what users see in the editor).
|
||||
//! Internally they are stored in the EvalContext as `__line_N` variables.
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Extract all line reference numbers from an expression string.
|
||||
/// Recognizes both `lineN` and `#N` syntax (case-insensitive for "line").
|
||||
pub fn extract_line_refs(input: &str) -> Vec<usize> {
|
||||
let mut refs = Vec::new();
|
||||
let bytes = input.as_bytes();
|
||||
let len = bytes.len();
|
||||
let mut i = 0;
|
||||
|
||||
while i < len {
|
||||
// Check for #N syntax
|
||||
if bytes[i] == b'#' && i + 1 < len && bytes[i + 1].is_ascii_digit() {
|
||||
i += 1;
|
||||
let start = i;
|
||||
while i < len && bytes[i].is_ascii_digit() {
|
||||
i += 1;
|
||||
}
|
||||
if let Ok(n) = input[start..i].parse::<usize>() {
|
||||
if !refs.contains(&n) {
|
||||
refs.push(n);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for lineN syntax (case-insensitive)
|
||||
if i + 4 < len {
|
||||
let word = &input[i..i + 4];
|
||||
if word.eq_ignore_ascii_case("line") {
|
||||
let after = i + 4;
|
||||
if after < len && bytes[after].is_ascii_digit() {
|
||||
// Check that the character before is not alphanumeric (word boundary)
|
||||
if i == 0 || !bytes[i - 1].is_ascii_alphanumeric() {
|
||||
let num_start = after;
|
||||
let mut j = after;
|
||||
while j < len && bytes[j].is_ascii_digit() {
|
||||
j += 1;
|
||||
}
|
||||
// Check that the character after is not alphanumeric (word boundary)
|
||||
if j >= len || !bytes[j].is_ascii_alphanumeric() {
|
||||
if let Ok(n) = input[num_start..j].parse::<usize>() {
|
||||
if !refs.contains(&n) {
|
||||
refs.push(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
i = j;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
refs
|
||||
}
|
||||
|
||||
/// Update line references in an expression string after a line insertion.
|
||||
///
|
||||
/// When a new line is inserted at position `insert_at` (1-indexed),
|
||||
/// all references to lines at or after that position are incremented by 1.
|
||||
pub fn renumber_after_insert(input: &str, insert_at: usize) -> String {
|
||||
renumber_refs(input, |line_num| {
|
||||
if line_num >= insert_at {
|
||||
line_num + 1
|
||||
} else {
|
||||
line_num
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Update line references in an expression string after a line deletion.
|
||||
///
|
||||
/// When a line is deleted at position `delete_at` (1-indexed),
|
||||
/// references to the deleted line become 0 (invalid).
|
||||
/// References to lines after the deleted one are decremented by 1.
|
||||
pub fn renumber_after_delete(input: &str, delete_at: usize) -> String {
|
||||
renumber_refs(input, |line_num| {
|
||||
if line_num == delete_at {
|
||||
0 // mark as invalid
|
||||
} else if line_num > delete_at {
|
||||
line_num - 1
|
||||
} else {
|
||||
line_num
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Apply a renumbering function to all line references in an expression string.
|
||||
/// Handles both `lineN` and `#N` syntax.
|
||||
fn renumber_refs<F>(input: &str, transform: F) -> String
|
||||
where
|
||||
F: Fn(usize) -> usize,
|
||||
{
|
||||
let mut result = String::with_capacity(input.len());
|
||||
let bytes = input.as_bytes();
|
||||
let len = bytes.len();
|
||||
let mut i = 0;
|
||||
|
||||
while i < len {
|
||||
// Check for #N syntax
|
||||
if bytes[i] == b'#' && i + 1 < len && bytes[i + 1].is_ascii_digit() {
|
||||
result.push('#');
|
||||
i += 1;
|
||||
let start = i;
|
||||
while i < len && bytes[i].is_ascii_digit() {
|
||||
i += 1;
|
||||
}
|
||||
if let Ok(n) = input[start..i].parse::<usize>() {
|
||||
let new_n = transform(n);
|
||||
result.push_str(&new_n.to_string());
|
||||
} else {
|
||||
result.push_str(&input[start..i]);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for lineN syntax (case-insensitive)
|
||||
if i + 4 < len {
|
||||
let word = &input[i..i + 4];
|
||||
if word.eq_ignore_ascii_case("line") {
|
||||
let after = i + 4;
|
||||
if after < len && bytes[after].is_ascii_digit() {
|
||||
if i == 0 || !bytes[i - 1].is_ascii_alphanumeric() {
|
||||
let prefix = &input[i..i + 4]; // preserve original case
|
||||
let num_start = after;
|
||||
let mut j = after;
|
||||
while j < len && bytes[j].is_ascii_digit() {
|
||||
j += 1;
|
||||
}
|
||||
if j >= len || !bytes[j].is_ascii_alphanumeric() {
|
||||
if let Ok(n) = input[num_start..j].parse::<usize>() {
|
||||
let new_n = transform(n);
|
||||
result.push_str(prefix);
|
||||
result.push_str(&new_n.to_string());
|
||||
i = j;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Safe: input is valid UTF-8, push the character
|
||||
let ch = input[i..].chars().next().unwrap();
|
||||
result.push(ch);
|
||||
i += ch.len_utf8();
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Detect circular line references given a set of lines and their line reference dependencies.
|
||||
///
|
||||
/// Returns the set of line numbers (1-indexed) that are involved in circular references.
|
||||
pub fn detect_circular_line_refs(
|
||||
line_refs: &[(usize, Vec<usize>)], // (line_number, referenced_lines)
|
||||
) -> HashSet<usize> {
|
||||
use std::collections::HashMap;
|
||||
|
||||
let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
|
||||
for (line_num, refs) in line_refs {
|
||||
adj.insert(*line_num, refs.clone());
|
||||
}
|
||||
|
||||
let mut circular = HashSet::new();
|
||||
|
||||
#[derive(Clone, Copy, PartialEq)]
|
||||
enum State {
|
||||
Unvisited,
|
||||
InProgress,
|
||||
Done,
|
||||
}
|
||||
|
||||
let mut state: HashMap<usize, State> = HashMap::new();
|
||||
for (line_num, _) in line_refs {
|
||||
state.insert(*line_num, State::Unvisited);
|
||||
}
|
||||
|
||||
fn dfs(
|
||||
node: usize,
|
||||
adj: &HashMap<usize, Vec<usize>>,
|
||||
state: &mut HashMap<usize, State>,
|
||||
path: &mut Vec<usize>,
|
||||
circular: &mut HashSet<usize>,
|
||||
) {
|
||||
if let Some(&s) = state.get(&node) {
|
||||
if s == State::Done {
|
||||
return;
|
||||
}
|
||||
if s == State::InProgress {
|
||||
// Found a cycle — mark all nodes in the cycle
|
||||
if let Some(start_idx) = path.iter().position(|&n| n == node) {
|
||||
for &n in &path[start_idx..] {
|
||||
circular.insert(n);
|
||||
}
|
||||
}
|
||||
circular.insert(node);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
state.insert(node, State::InProgress);
|
||||
path.push(node);
|
||||
|
||||
if let Some(deps) = adj.get(&node) {
|
||||
for &dep in deps {
|
||||
if adj.contains_key(&dep) {
|
||||
dfs(dep, adj, state, path, circular);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
path.pop();
|
||||
state.insert(node, State::Done);
|
||||
}
|
||||
|
||||
for (line_num, _) in line_refs {
|
||||
if state.get(line_num) == Some(&State::Unvisited) {
|
||||
let mut path = Vec::new();
|
||||
dfs(*line_num, &adj, &mut state, &mut path, &mut circular);
|
||||
}
|
||||
}
|
||||
|
||||
circular
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// --- extract_line_refs ---
|
||||
|
||||
#[test]
|
||||
fn test_extract_hash_ref() {
|
||||
assert_eq!(extract_line_refs("#1 * 2"), vec![1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_line_ref() {
|
||||
assert_eq!(extract_line_refs("line1 * 2"), vec![1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_line_ref_case_insensitive() {
|
||||
assert_eq!(extract_line_refs("Line3 + Line1"), vec![3, 1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_multiple_refs() {
|
||||
assert_eq!(extract_line_refs("#1 + #2 * line3"), vec![1, 2, 3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_no_refs() {
|
||||
assert_eq!(extract_line_refs("x + 5"), Vec::<usize>::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_dedup() {
|
||||
assert_eq!(extract_line_refs("#1 + #1"), vec![1]);
|
||||
}
|
||||
|
||||
// --- renumber_after_insert ---
|
||||
|
||||
#[test]
|
||||
fn test_renumber_insert_shifts_refs_at_or_after() {
|
||||
assert_eq!(renumber_after_insert("#1 + #2", 1), "#2 + #3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_renumber_insert_no_shift_before() {
|
||||
assert_eq!(renumber_after_insert("#1 + #2", 3), "#1 + #2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_renumber_insert_line_syntax() {
|
||||
assert_eq!(renumber_after_insert("line2 * 3", 1), "line3 * 3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_renumber_insert_mixed() {
|
||||
assert_eq!(renumber_after_insert("#1 + line3", 2), "#1 + line4");
|
||||
}
|
||||
|
||||
// --- renumber_after_delete ---
|
||||
|
||||
#[test]
|
||||
fn test_renumber_delete_marks_deleted_as_zero() {
|
||||
assert_eq!(renumber_after_delete("#2 * 3", 2), "#0 * 3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_renumber_delete_shifts_after() {
|
||||
assert_eq!(renumber_after_delete("#1 + #3", 2), "#1 + #2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_renumber_delete_line_syntax() {
|
||||
assert_eq!(renumber_after_delete("line3 + 5", 2), "line2 + 5");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_renumber_delete_no_change_before() {
|
||||
assert_eq!(renumber_after_delete("#1 + #2", 5), "#1 + #2");
|
||||
}
|
||||
|
||||
// --- detect_circular_line_refs ---
|
||||
|
||||
#[test]
|
||||
fn test_no_cycles() {
|
||||
let refs = vec![(1, vec![]), (2, vec![1]), (3, vec![1, 2])];
|
||||
let circular = detect_circular_line_refs(&refs);
|
||||
assert!(circular.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_direct_cycle() {
|
||||
let refs = vec![(1, vec![2]), (2, vec![1])];
|
||||
let circular = detect_circular_line_refs(&refs);
|
||||
assert!(circular.contains(&1));
|
||||
assert!(circular.contains(&2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transitive_cycle() {
|
||||
let refs = vec![(1, vec![2]), (2, vec![3]), (3, vec![1])];
|
||||
let circular = detect_circular_line_refs(&refs);
|
||||
assert!(circular.contains(&1));
|
||||
assert!(circular.contains(&2));
|
||||
assert!(circular.contains(&3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_self_reference() {
|
||||
let refs = vec![(1, vec![1])];
|
||||
let circular = detect_circular_line_refs(&refs);
|
||||
assert!(circular.contains(&1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_cycle() {
|
||||
// Line 4 depends on line 2 which is in a cycle, but line 4 is not in the cycle itself
|
||||
let refs = vec![(1, vec![2]), (2, vec![1]), (3, vec![]), (4, vec![2])];
|
||||
let circular = detect_circular_line_refs(&refs);
|
||||
assert!(circular.contains(&1));
|
||||
assert!(circular.contains(&2));
|
||||
assert!(!circular.contains(&3));
|
||||
// Line 4 is not in the cycle (it just references a cyclic line)
|
||||
assert!(!circular.contains(&4));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user