From 17ddc89bd025dd0304b858f39a264fe707ea40bc Mon Sep 17 00:00:00 2001 From: rustdesk Date: Fri, 27 Jan 2023 11:00:59 +0800 Subject: [PATCH] sync rustdesk's hbb_common here --- libs/hbb_common/build.rs | 7 ++- libs/hbb_common/protos/message.proto | 35 +++++++++++-- libs/hbb_common/src/config.rs | 14 ++++-- libs/hbb_common/src/keyboard.rs | 39 +++++++++++++++ libs/hbb_common/src/lib.rs | 34 +++++++++++++ libs/hbb_common/src/platform/linux.rs | 8 +-- libs/hbb_common/src/socket_client.rs | 71 +++++++++++++++++++++++++-- 7 files changed, 192 insertions(+), 16 deletions(-) create mode 100644 libs/hbb_common/src/keyboard.rs diff --git a/libs/hbb_common/build.rs b/libs/hbb_common/build.rs index 225ec34..bff0cfa 100644 --- a/libs/hbb_common/build.rs +++ b/libs/hbb_common/build.rs @@ -1,8 +1,11 @@ fn main() { - std::fs::create_dir_all("src/protos").unwrap(); + let out_dir = format!("{}/protos", std::env::var("OUT_DIR").unwrap()); + + std::fs::create_dir_all(&out_dir).unwrap(); + protobuf_codegen::Codegen::new() .pure() - .out_dir("src/protos") + .out_dir(out_dir) .inputs(&["protos/rendezvous.proto", "protos/message.proto"]) .include("protos") .customize( diff --git a/libs/hbb_common/protos/message.proto b/libs/hbb_common/protos/message.proto index b127ac3..b7965f2 100644 --- a/libs/hbb_common/protos/message.proto +++ b/libs/hbb_common/protos/message.proto @@ -445,7 +445,7 @@ enum ImageQuality { } message VideoCodecState { - enum PerferCodec { + enum PreferCodec { Auto = 0; VPX = 1; H264 = 2; @@ -455,7 +455,7 @@ message VideoCodecState { int32 score_vpx = 1; int32 score_h264 = 2; int32 score_h265 = 3; - PerferCodec perfer = 4; + PreferCodec prefer = 4; } message OptionMessage { @@ -503,7 +503,7 @@ message AudioFrame { // Notify peer to show message box. message MessageBox { - // Message type. Refer to flutter/lib/commom.dart/msgBox(). + // Message type. Refer to flutter/lib/common.dart/msgBox(). string msgtype = 1; string title = 2; // English @@ -552,6 +552,29 @@ message BackNotification { } } +message ElevationRequestWithLogon { + string username = 1; + string password = 2; +} + +message ElevationRequest { + oneof union { + bool direct = 1; + ElevationRequestWithLogon logon = 2; + } +} + +message SwitchSidesRequest { + bytes uuid = 1; +} + +message SwitchSidesResponse { + bytes uuid = 1; + LoginRequest lr = 2; +} + +message SwitchBack {} + message Misc { oneof union { ChatMessage chat_message = 4; @@ -567,6 +590,11 @@ message Misc { bool uac = 15; bool foreground_window_elevated = 16; bool stop_service = 17; + ElevationRequest elevation_request = 18; + string elevation_response = 19; + bool portable_service_running = 20; + SwitchSidesRequest switch_sides_request = 21; + SwitchBack switch_back = 22; } } @@ -591,5 +619,6 @@ message Message { Misc misc = 19; Cliprdr cliprdr = 20; MessageBox message_box = 21; + SwitchSidesResponse switch_sides_response = 22; } } diff --git a/libs/hbb_common/src/config.rs b/libs/hbb_common/src/config.rs index 1d427a2..20334ed 100644 --- a/libs/hbb_common/src/config.rs +++ b/libs/hbb_common/src/config.rs @@ -49,7 +49,10 @@ lazy_static::lazy_static! { static ref CONFIG2: Arc> = Arc::new(RwLock::new(Config2::load())); static ref LOCAL_CONFIG: Arc> = Arc::new(RwLock::new(LocalConfig::load())); pub static ref ONLINE: Arc>> = Default::default(); - pub static ref PROD_RENDEZVOUS_SERVER: Arc> = Default::default(); + pub static ref PROD_RENDEZVOUS_SERVER: Arc> = Arc::new(RwLock::new(match option_env!("RENDEZVOUS_SERVER") { + Some(key) if !key.is_empty() => key, + _ => "", + }.to_owned())); pub static ref APP_NAME: Arc> = Arc::new(RwLock::new("RustDesk".to_owned())); static ref KEY_PAIR: Arc, Vec)>>> = Default::default(); static ref HW_CODEC_CONFIG: Arc> = Arc::new(RwLock::new(HwCodecConfig::load())); @@ -77,12 +80,17 @@ const CHARS: &'static [char] = &[ 'm', 'n', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ]; -pub const RENDEZVOUS_SERVERS: &'static [&'static str] = &[ +const RENDEZVOUS_SERVERS: &'static [&'static str] = &[ "rs-ny.rustdesk.com", "rs-sg.rustdesk.com", "rs-cn.rustdesk.com", ]; -pub const RS_PUB_KEY: &'static str = "OeVuKk5nlHiXp+APNn0Y3pC1Iwpwn44JGqrQCsWqmBw="; + +pub const RS_PUB_KEY: &'static str = match option_env!("RS_PUB_KEY") { + Some(key) if !key.is_empty() => key, + _ => "OeVuKk5nlHiXp+APNn0Y3pC1Iwpwn44JGqrQCsWqmBw=", +}; + pub const RENDEZVOUS_PORT: i32 = 21116; pub const RELAY_PORT: i32 = 21117; diff --git a/libs/hbb_common/src/keyboard.rs b/libs/hbb_common/src/keyboard.rs new file mode 100644 index 0000000..10979f5 --- /dev/null +++ b/libs/hbb_common/src/keyboard.rs @@ -0,0 +1,39 @@ +use std::{fmt, slice::Iter, str::FromStr}; + +use crate::protos::message::KeyboardMode; + +impl fmt::Display for KeyboardMode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + KeyboardMode::Legacy => write!(f, "legacy"), + KeyboardMode::Map => write!(f, "map"), + KeyboardMode::Translate => write!(f, "translate"), + KeyboardMode::Auto => write!(f, "auto"), + } + } +} + +impl FromStr for KeyboardMode { + type Err = (); + fn from_str(s: &str) -> Result { + match s { + "legacy" => Ok(KeyboardMode::Legacy), + "map" => Ok(KeyboardMode::Map), + "translate" => Ok(KeyboardMode::Translate), + "auto" => Ok(KeyboardMode::Auto), + _ => Err(()), + } + } +} + +impl KeyboardMode { + pub fn iter() -> Iter<'static, KeyboardMode> { + static KEYBOARD_MODES: [KeyboardMode; 4] = [ + KeyboardMode::Legacy, + KeyboardMode::Map, + KeyboardMode::Translate, + KeyboardMode::Auto, + ]; + KEYBOARD_MODES.iter() + } +} diff --git a/libs/hbb_common/src/lib.rs b/libs/hbb_common/src/lib.rs index 85e0100..e57994f 100644 --- a/libs/hbb_common/src/lib.rs +++ b/libs/hbb_common/src/lib.rs @@ -40,6 +40,7 @@ pub use tokio_socks::TargetAddr; pub mod password_security; pub use chrono; pub use directories_next; +pub mod keyboard; #[cfg(feature = "quic")] pub type Stream = quic::Connection; @@ -320,6 +321,18 @@ pub fn is_ip_str(id: &str) -> bool { is_ipv4_str(id) || is_ipv6_str(id) } +#[inline] +pub fn is_domain_port_str(id: &str) -> bool { + // modified regex for RFC1123 hostname. check https://stackoverflow.com/a/106223 for original version for hostname. + // according to [TLD List](https://data.iana.org/TLD/tlds-alpha-by-domain.txt) version 2023011700, + // there is no digits in TLD, and length is 2~63. + regex::Regex::new( + r"(?i)^([a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z][a-z-]{0,61}[a-z]:\d{1,5}$", + ) + .unwrap() + .is_match(id) +} + #[cfg(test)] mod test_lib { use super::*; @@ -339,4 +352,25 @@ mod test_lib { assert_eq!(is_ipv6_str("[1:2::0]:"), false); assert_eq!(is_ipv6_str("1:2::0]:1"), false); } + + #[test] + fn test_hostname_port() { + assert_eq!(is_domain_port_str("a:12"), false); + assert_eq!(is_domain_port_str("a.b.c:12"), false); + assert_eq!(is_domain_port_str("test.com:12"), true); + assert_eq!(is_domain_port_str("test-UPPER.com:12"), true); + assert_eq!(is_domain_port_str("some-other.domain.com:12"), true); + assert_eq!(is_domain_port_str("under_score:12"), false); + assert_eq!(is_domain_port_str("a@bc:12"), false); + assert_eq!(is_domain_port_str("1.1.1.1:12"), false); + assert_eq!(is_domain_port_str("1.2.3:12"), false); + assert_eq!(is_domain_port_str("1.2.3.45:12"), false); + assert_eq!(is_domain_port_str("a.b.c:123456"), false); + assert_eq!(is_domain_port_str("---:12"), false); + assert_eq!(is_domain_port_str(".:12"), false); + // todo: should we also check for these edge cases? + // out-of-range port + assert_eq!(is_domain_port_str("test.com:0"), true); + assert_eq!(is_domain_port_str("test.com:98989"), true); + } } diff --git a/libs/hbb_common/src/platform/linux.rs b/libs/hbb_common/src/platform/linux.rs index 4c6375d..e824163 100644 --- a/libs/hbb_common/src/platform/linux.rs +++ b/libs/hbb_common/src/platform/linux.rs @@ -1,15 +1,15 @@ use crate::ResultType; lazy_static::lazy_static! { - pub static ref DISTRO: Disto = Disto::new(); + pub static ref DISTRO: Distro = Distro::new(); } -pub struct Disto { +pub struct Distro { pub name: String, pub version_id: String, } -impl Disto { +impl Distro { fn new() -> Self { let name = run_cmds("awk -F'=' '/^NAME=/ {print $2}' /etc/os-release".to_owned()) .unwrap_or_default() @@ -74,7 +74,7 @@ fn get_display_server_of_session(session: &str) -> String { } else { "".to_owned() }; - if display_server.is_empty() { + if display_server.is_empty() || display_server == "tty" { // loginctl has not given the expected output. try something else. if let Ok(sestype) = std::env::var("XDG_SESSION_TYPE") { display_server = sestype; diff --git a/libs/hbb_common/src/socket_client.rs b/libs/hbb_common/src/socket_client.rs index b7cb137..6f62163 100644 --- a/libs/hbb_common/src/socket_client.rs +++ b/libs/hbb_common/src/socket_client.rs @@ -9,11 +9,48 @@ use std::net::SocketAddr; use tokio::net::ToSocketAddrs; use tokio_socks::{IntoTargetAddr, TargetAddr}; -pub fn test_if_valid_server(host: &str) -> String { - let mut host = host.to_owned(); - if !host.contains(":") { - host = format!("{}:{}", host, 0); +#[inline] +pub fn check_port(host: T, port: i32) -> String { + let host = host.to_string(); + if crate::is_ipv6_str(&host) { + if host.starts_with("[") { + return host; + } + return format!("[{}]:{}", host, port); } + if !host.contains(":") { + return format!("{}:{}", host, port); + } + return host; +} + +#[inline] +pub fn increase_port(host: T, offset: i32) -> String { + let host = host.to_string(); + if crate::is_ipv6_str(&host) { + if host.starts_with("[") { + let tmp: Vec<&str> = host.split("]:").collect(); + if tmp.len() == 2 { + let port: i32 = tmp[1].parse().unwrap_or(0); + if port > 0 { + return format!("{}]:{}", tmp[0], port + offset); + } + } + } + } else if host.contains(":") { + let tmp: Vec<&str> = host.split(":").collect(); + if tmp.len() == 2 { + let port: i32 = tmp[1].parse().unwrap_or(0); + if port > 0 { + return format!("{}:{}", tmp[0], port + offset); + } + } + } + return host; +} + +pub fn test_if_valid_server(host: &str) -> String { + let host = check_port(host, 0); use std::net::ToSocketAddrs; match Config::get_network_type() { @@ -216,4 +253,30 @@ mod tests { } assert!(query_nip_io(&"1.1.1.1:80".parse().unwrap()).await.is_err()); } + + #[test] + fn test_test_if_valid_server() { + assert!(!test_if_valid_server("a").is_empty()); + // on Linux, "1" is resolved to "0.0.0.1" + assert!(test_if_valid_server("1.1.1.1").is_empty()); + assert!(test_if_valid_server("1.1.1.1:1").is_empty()); + } + + #[test] + fn test_check_port() { + assert_eq!(check_port("[1:2]:12", 32), "[1:2]:12"); + assert_eq!(check_port("1:2", 32), "[1:2]:32"); + assert_eq!(check_port("z1:2", 32), "z1:2"); + assert_eq!(check_port("1.1.1.1", 32), "1.1.1.1:32"); + assert_eq!(check_port("1.1.1.1:32", 32), "1.1.1.1:32"); + assert_eq!(check_port("test.com:32", 0), "test.com:32"); + assert_eq!(increase_port("[1:2]:12", 1), "[1:2]:13"); + assert_eq!(increase_port("1.2.2.4:12", 1), "1.2.2.4:13"); + assert_eq!(increase_port("1.2.2.4", 1), "1.2.2.4"); + assert_eq!(increase_port("test.com", 1), "test.com"); + assert_eq!(increase_port("test.com:13", 4), "test.com:17"); + assert_eq!(increase_port("1:13", 4), "1:13"); + assert_eq!(increase_port("22:1:13", 4), "22:1:13"); + assert_eq!(increase_port("z1:2", 1), "z1:3"); + } }