diff --git a/Cargo.toml b/Cargo.toml index 29f51e2..f29101a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,4 +14,5 @@ license = "MIT" simple_logger = "2.1.0" log = "0.4.14" tokio = { version = "1", features = ["full"] } -getopts = "0.2" \ No newline at end of file +getopts = "0.2" +net2 = "0.2" \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 80f59d2..08c82ba 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ mod utils; mod socks; use log::LevelFilter; +use net2::TcpStreamExt; use simple_logger::SimpleLogger; use getopts::Options; use tokio::{io::{self, AsyncWriteExt, AsyncReadExt}, task, net::{TcpListener, TcpStream}}; @@ -77,6 +78,11 @@ async fn main() -> io::Result<()> { loop{ let (stream , _) = listener.accept().await.unwrap(); + + let raw_stream = stream.into_std().unwrap(); + raw_stream.set_keepalive(Some(std::time::Duration::from_secs(10))).unwrap(); + let stream = TcpStream::from_std(raw_stream).unwrap(); + tokio::spawn(async { socks::socksv5_handle(stream).await; }); @@ -119,7 +125,7 @@ async fn main() -> io::Result<()> { Ok(p) => p }; - let (mut slave_stream , slave_addr) = match slave_listener.accept().await{ + let (slave_stream , slave_addr) = match slave_listener.accept().await{ Err(e) => { log::error!("error : {}", e); return Ok(()); @@ -127,6 +133,10 @@ async fn main() -> io::Result<()> { Ok(p) => p }; + let raw_stream = slave_stream.into_std().unwrap(); + raw_stream.set_keepalive(Some(std::time::Duration::from_secs(10))).unwrap(); + let mut slave_stream = TcpStream::from_std(raw_stream).unwrap(); + log::info!("accept slave from : {}:{}" , slave_addr.ip() , slave_addr.port() ); log::info!("listen to : {}" , socks_addr); @@ -140,7 +150,11 @@ async fn main() -> io::Result<()> { }; loop { - let (mut stream , _) = listener.accept().await.unwrap(); + let (stream , _) = listener.accept().await.unwrap(); + + let raw_stream = stream.into_std().unwrap(); + raw_stream.set_keepalive(Some(std::time::Duration::from_secs(10))).unwrap(); + let mut stream = TcpStream::from_std(raw_stream).unwrap(); match slave_stream.write_all(&[MAGIC_FLAG[0]]).await{ Err(e) => { @@ -150,7 +164,7 @@ async fn main() -> io::Result<()> { _ => {} }; - let (mut proxy_stream , slave_addr) = match slave_listener.accept().await{ + let (proxy_stream , slave_addr) = match slave_listener.accept().await{ Err(e) => { log::error!("error : {}", e); return Ok(()); @@ -158,6 +172,10 @@ async fn main() -> io::Result<()> { Ok(p) => p }; + let raw_stream = proxy_stream.into_std().unwrap(); + raw_stream.set_keepalive(Some(std::time::Duration::from_secs(10))).unwrap(); + let mut proxy_stream = TcpStream::from_std(raw_stream).unwrap(); + log::info!("accept from slave : {}:{}" , slave_addr.ip() , slave_addr.port() ); task::spawn(async move { @@ -222,13 +240,18 @@ async fn main() -> io::Result<()> { Ok(p) => p }; - let mut master_stream = match TcpStream::connect(fulladdr.clone()).await{ + let master_stream = match TcpStream::connect(fulladdr.clone()).await{ Err(e) => { log::error!("error : {}", e); return Ok(()); }, Ok(p) => p }; + + let raw_stream = master_stream.into_std().unwrap(); + raw_stream.set_keepalive(Some(std::time::Duration::from_secs(10))).unwrap(); + let mut master_stream = TcpStream::from_std(raw_stream).unwrap(); + log::info!("connect to {} success" ,fulladdr ); loop { let mut buf = [0u8 ; 1]; @@ -249,6 +272,10 @@ async fn main() -> io::Result<()> { Ok(p) => p }; + let raw_stream = stream.into_std().unwrap(); + raw_stream.set_keepalive(Some(std::time::Duration::from_secs(10))).unwrap(); + let stream = TcpStream::from_std(raw_stream).unwrap(); + task::spawn(async { socks::socksv5_handle(stream).await; }); diff --git a/src/socks.rs b/src/socks.rs index 124cb0d..b02eaa5 100644 --- a/src/socks.rs +++ b/src/socks.rs @@ -1,5 +1,6 @@ use std::{ net::{Ipv6Addr, SocketAddrV6}}; +use net2::TcpStreamExt; use tokio::{net::TcpStream, io::{AsyncWriteExt, AsyncReadExt}}; use crate::utils::makeword; @@ -56,7 +57,7 @@ async fn tcp_transfer(stream : &mut TcpStream , addr : &Addr, address : &String } }; - let mut client = match client { + let client = match client { Err(_) => { log::warn!("connect[{}] faild" , address); return; @@ -64,6 +65,10 @@ async fn tcp_transfer(stream : &mut TcpStream , addr : &Addr, address : &String Ok(p) => p }; + let raw_stream = client.into_std().unwrap(); + raw_stream.set_keepalive(Some(std::time::Duration::from_secs(10))).unwrap(); + let mut client = TcpStream::from_std(raw_stream).unwrap(); + let remote_port = client.local_addr().unwrap().port(); let mut reply = Vec::with_capacity(22);