Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Supports multiple servers in single client #310

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README-zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ local_addr = "127.0.0.1:22" # 需要被转发的服务的地址
```toml
[client]
remote_addr = "example.com:2333" # Necessary. The address of the server
# remote_addr = "foo.com:2333,bar.com:3332" # Multiple server addresses, this is a preview feature and address changes cannot be hot loaded now
default_token = "default_token_if_not_specify" # Optional. The default token of services, if they don't define their own ones
heartbeat_timeout = 40 # Optional. Set to 0 to disable the application-layer heartbeat test. The value must be greater than `server.heartbeat_interval`. Default: 40 seconds
retry_interval = 1 # Optional. The interval between retry to connect to the server. Default: 1 second
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ Here is the full configuration specification:
```toml
[client]
remote_addr = "example.com:2333" # Necessary. The address of the server
# remote_addr = "foo.com:2333,bar.com:3332" # Multiple server addresses, this is a preview feature and address changes cannot be hot loaded now
default_token = "default_token_if_not_specify" # Optional. The default token of services, if they don't define their own ones
heartbeat_timeout = 40 # Optional. Set to 0 to disable the application-layer heartbeat test. The value must be greater than `server.heartbeat_interval`. Default: 40 seconds
retry_interval = 1 # Optional. The interval between retry to connect to the server. Default: 1 second
Expand Down
11 changes: 11 additions & 0 deletions examples/minimal/client.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
[client]
name = "pve"
remote_addr = "localhost:2333"
default_token = "123"

[client.services.foo1]
local_addr = "127.0.0.1:80"

[servers.hk]
remote_addr = "localhost:2333"
default_token = "123"

[servers.cn]
remote_addr = "localhost:2444"
default_token = "123"


13 changes: 13 additions & 0 deletions examples/minimal/server.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,16 @@ default_token = "123"

[server.services.foo1]
bind_addr = "0.0.0.0:5202"

[proxies.pve.ssh]
bind_addr = "0.0.0.0:5001"
local_addr = "127.0.0.1:22"

[proxies.pve.rdp]
bind_addr = "0.0.0.0:5002"
local_addr = "127.0.0.1:3389"

[proxies.router.web]
bind_addr = "0.0.0.0:5003"
local_addr = "127.0.0.1:80"

69 changes: 43 additions & 26 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType};
use crate::config::{ClientConfig, ClientServerConfig, ClientServiceConfig, Config, ServiceType, TransportType};
use crate::config_watcher::{ClientServiceChange, ConfigChange};
use crate::helper::udp_connect;
use crate::protocol::Hello::{self, *};
Expand Down Expand Up @@ -35,21 +35,23 @@ pub async fn run_client(
shutdown_rx: broadcast::Receiver<bool>,
update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
let config = config.client.ok_or_else(|| {
let client_config = config.client.ok_or_else(|| {
anyhow!(
"Try to run as a client, but the configuration is missing. Please add the `[client]` block"
)
})?;

match config.transport.transport_type {
let servers_config = config.servers.unwrap_or_else(|| HashMap::new());

match client_config.transport.transport_type {
TransportType::Tcp => {
let mut client = Client::<TcpTransport>::from(config).await?;
let mut client = Client::<TcpTransport>::from(client_config, servers_config).await?;
client.run(shutdown_rx, update_rx).await
}
TransportType::Tls => {
#[cfg(feature = "tls")]
{
let mut client = Client::<TlsTransport>::from(config).await?;
let mut client = Client::<TlsTransport>::from(client_config, servers_config).await?;
client.run(shutdown_rx, update_rx).await
}
#[cfg(not(feature = "tls"))]
Expand All @@ -58,7 +60,7 @@ pub async fn run_client(
TransportType::Noise => {
#[cfg(feature = "noise")]
{
let mut client = Client::<NoiseTransport>::from(config).await?;
let mut client = Client::<NoiseTransport>::from(client_config, servers_config).await?;
client.run(shutdown_rx, update_rx).await
}
#[cfg(not(feature = "noise"))]
Expand All @@ -67,7 +69,7 @@ pub async fn run_client(
TransportType::Websocket => {
#[cfg(feature = "websocket")]
{
let mut client = Client::<WebsocketTransport>::from(config).await?;
let mut client = Client::<WebsocketTransport>::from(client_config, servers_config).await?;
client.run(shutdown_rx, update_rx).await
}
#[cfg(not(feature = "websocket"))]
Expand All @@ -82,17 +84,19 @@ type Nonce = protocol::Digest;
// Holds the state of a client
struct Client<T: Transport> {
config: ClientConfig,
servers: HashMap<String, ClientServerConfig>,
service_handles: HashMap<String, ControlChannelHandle>,
transport: Arc<T>,
}

impl<T: 'static + Transport> Client<T> {
// Create a Client from `[client]` config block
async fn from(config: ClientConfig) -> Result<Client<T>> {
async fn from(config: ClientConfig, servers: HashMap<String, ClientServerConfig>) -> Result<Client<T>> {
let transport =
Arc::new(T::new(&config.transport).with_context(|| "Failed to create the transport")?);
Ok(Client {
config,
servers,
service_handles: HashMap::new(),
transport,
})
Expand All @@ -104,15 +108,20 @@ impl<T: 'static + Transport> Client<T> {
mut shutdown_rx: broadcast::Receiver<bool>,
mut update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
for (name, config) in &self.config.services {
// Create a control channel for each service defined
let handle = ControlChannelHandle::new(
(*config).clone(),
self.config.remote_addr.clone(),
self.transport.clone(),
self.config.heartbeat_timeout,
);
self.service_handles.insert(name.clone(), handle);
for (server_name, server_config) in self.servers.iter() {
info!("server_name={}, server_config={:?}", server_name, server_config);
let remote_addr = &server_config.remote_addr;
for (service_name, service_config) in &self.config.services {
// Create a control channel for each service defined
let handle = ControlChannelHandle::new(
(*service_config).clone(),
remote_addr.to_string(),
self.transport.clone(),
self.config.heartbeat_timeout,
);
let full_name = server_name.to_string() + "," + service_name;
self.service_handles.insert(full_name, handle);
}
}

// Wait for the shutdown signal
Expand Down Expand Up @@ -147,17 +156,25 @@ impl<T: 'static + Transport> Client<T> {
match e {
ConfigChange::ClientChange(client_change) => match client_change {
ClientServiceChange::Add(cfg) => {
let name = cfg.name.clone();
let handle = ControlChannelHandle::new(
cfg,
self.config.remote_addr.clone(),
self.transport.clone(),
self.config.heartbeat_timeout,
);
let _ = self.service_handles.insert(name, handle);
for (server_name, server_config) in self.servers.iter() {
info!("server_name={}, server_config={:?}", server_name, server_config);
let remote_addr = &server_config.remote_addr;
let handle = ControlChannelHandle::new(
cfg.clone(),
remote_addr.to_string(),
self.transport.clone(),
self.config.heartbeat_timeout,
);
let full_name = server_name.to_string() + "," + &cfg.name;
let _ = self.service_handles.insert(full_name, handle);
}
}
ClientServiceChange::Delete(s) => {
let _ = self.service_handles.remove(&s);
for (server_name, server_config) in self.servers.iter() {
info!("server_name={}, server_config={:?}", server_name, server_config);
let full_name = server_name.to_string() + "," + &s;
let _ = self.service_handles.remove(&full_name);
}
}
},
ignored => warn!("Ignored {:?} since running as a client", ignored),
Expand Down
23 changes: 23 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ fn default_client_retry_interval() -> u64 {
#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Eq, Clone)]
#[serde(deny_unknown_fields)]
pub struct ClientConfig {
pub name: Option<String>,
pub remote_addr: String,
pub default_token: Option<MaskedString>,
pub services: HashMap<String, ClientServiceConfig>,
Expand All @@ -210,6 +211,19 @@ pub struct ClientConfig {
pub retry_interval: u64,
}

#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Eq, Clone)]
#[serde(deny_unknown_fields)]
pub struct ClientServerConfig {
pub remote_addr: String,
pub default_token: Option<MaskedString>,
#[serde(default)]
pub transport: TransportConfig,
#[serde(default = "default_heartbeat_timeout")]
pub heartbeat_timeout: u64,
#[serde(default = "default_client_retry_interval")]
pub retry_interval: u64,
}

fn default_heartbeat_interval() -> u64 {
DEFAULT_HEARTBEAT_INTERVAL_SECS
}
Expand All @@ -226,11 +240,20 @@ pub struct ServerConfig {
pub heartbeat_interval: u64,
}

#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Eq, Clone)]
#[serde(deny_unknown_fields)]
pub struct ProxyConfig {
pub bind_addr: String,
pub local_addr: String,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[serde(deny_unknown_fields)]
pub struct Config {
pub server: Option<ServerConfig>,
pub proxies: Option<HashMap<String, HashMap<String, ProxyConfig>>>,
pub client: Option<ClientConfig>,
pub servers: Option<HashMap<String, ClientServerConfig>>,
}

impl Config {
Expand Down
43 changes: 29 additions & 14 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
use crate::config::{Config, ProxyConfig, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
use crate::config_watcher::{ConfigChange, ServerServiceChange};
use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
use crate::helper::retry_notify_with_deadline;
Expand Down Expand Up @@ -44,22 +44,24 @@ pub async fn run_server(
shutdown_rx: broadcast::Receiver<bool>,
update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
let config = match config.server {
let server_config = match config.server {
Some(config) => config,
None => {
return Err(anyhow!("Try to run as a server, but the configuration is missing. Please add the `[server]` block"))
}
};

match config.transport.transport_type {
let proxy_config = config.proxies.unwrap_or_else(|| HashMap::new());

match server_config.transport.transport_type {
TransportType::Tcp => {
let mut server = Server::<TcpTransport>::from(config).await?;
let mut server = Server::<TcpTransport>::from(server_config, proxy_config).await?;
server.run(shutdown_rx, update_rx).await?;
}
TransportType::Tls => {
#[cfg(feature = "tls")]
{
let mut server = Server::<TlsTransport>::from(config).await?;
let mut server = Server::<TlsTransport>::from(server_config, proxy_config).await?;
server.run(shutdown_rx, update_rx).await?;
}
#[cfg(not(feature = "tls"))]
Expand All @@ -68,7 +70,7 @@ pub async fn run_server(
TransportType::Noise => {
#[cfg(feature = "noise")]
{
let mut server = Server::<NoiseTransport>::from(config).await?;
let mut server = Server::<NoiseTransport>::from(server_config, proxy_config).await?;
server.run(shutdown_rx, update_rx).await?;
}
#[cfg(not(feature = "noise"))]
Expand All @@ -77,7 +79,7 @@ pub async fn run_server(
TransportType::Websocket => {
#[cfg(feature = "websocket")]
{
let mut server = Server::<WebsocketTransport>::from(config).await?;
let mut server = Server::<WebsocketTransport>::from(server_config, proxy_config).await?;
server.run(shutdown_rx, update_rx).await?;
}
#[cfg(not(feature = "websocket"))]
Expand All @@ -96,7 +98,7 @@ type ControlChannelMap<T> = MultiMap<ServiceDigest, Nonce, ControlChannelHandle<
struct Server<T: Transport> {
// `[server]` config
config: Arc<ServerConfig>,

proxies: Arc<HashMap<String, HashMap<String, ProxyConfig>>>,
// `[server.services]` config, indexed by ServiceDigest
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
// Collection of contorl channels
Expand All @@ -118,13 +120,15 @@ fn generate_service_hashmap(

impl<T: 'static + Transport> Server<T> {
// Create a server from `[server]`
pub async fn from(config: ServerConfig) -> Result<Server<T>> {
let config = Arc::new(config);
let services = Arc::new(RwLock::new(generate_service_hashmap(&config)));
pub async fn from(server_config: ServerConfig, proxies: HashMap<String, HashMap<String, ProxyConfig>>) -> Result<Server<T>> {
let server_config = Arc::new(server_config);
let proxies = Arc::new(proxies);
let services = Arc::new(RwLock::new(generate_service_hashmap(&server_config)));
let control_channels = Arc::new(RwLock::new(ControlChannelMap::new()));
let transport = Arc::new(T::new(&config.transport)?);
let transport = Arc::new(T::new(&server_config.transport)?);
Ok(Server {
config,
config: server_config,
proxies,
services,
control_channels,
transport,
Expand Down Expand Up @@ -152,6 +156,14 @@ impl<T: 'static + Transport> Server<T> {
..Default::default()
};

// Log initial services
let proxies = self.proxies.clone();
for (client_name, services) in proxies.iter() {
for (service_name, service_config) in services.iter() {
info!("client_name={}, service_name={}: {:?}", client_name, service_name, service_config);
}
}

// Wait for connections and shutdown signals
loop {
tokio::select! {
Expand Down Expand Up @@ -185,10 +197,11 @@ impl<T: 'static + Transport> Server<T> {
match conn.with_context(|| "Failed to do transport handshake") {
Ok(conn) => {
let services = self.services.clone();
let proxies = self.proxies.clone();
let control_channels = self.control_channels.clone();
let server_config = self.config.clone();
tokio::spawn(async move {
if let Err(err) = handle_connection(conn, services, control_channels, server_config).await {
if let Err(err) = handle_connection(conn, services, proxies, control_channels, server_config).await {
error!("{:#}", err);
}
}.instrument(info_span!("connection", %addr)));
Expand Down Expand Up @@ -250,13 +263,15 @@ impl<T: 'static + Transport> Server<T> {
async fn handle_connection<T: 'static + Transport>(
mut conn: T::Stream,
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
proxies: Arc<HashMap<String, HashMap<String, ProxyConfig>>>,
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
server_config: Arc<ServerConfig>,
) -> Result<()> {
// Read hello
let hello = read_hello(&mut conn).await?;
match hello {
ControlChannelHello(_, service_digest) => {
debug!("server proxies: {:?}", proxies);
do_control_channel_handshake(
conn,
services,
Expand Down
Loading