Skip to content

Commit

Permalink
refine code
Browse files Browse the repository at this point in the history
  • Loading branch information
ssrlive committed Aug 29, 2024
1 parent 913652e commit c0f17ab
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 63 deletions.
66 changes: 31 additions & 35 deletions src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use windows_sys::{
/// Wrapper around a <https://git.zx2c4.com/wintun/about/#wintun_adapter_handle>
pub struct Adapter {
adapter: UnsafeHandle<wintun_raw::WINTUN_ADAPTER_HANDLE>,
wintun: Wintun,
pub(crate) wintun: Wintun,
guid: u128,
index: u32,
luid: NET_LUID_LH,
Expand Down Expand Up @@ -89,18 +89,17 @@ impl Adapter {
let result = unsafe { wintun.WintunCreateAdapter(name_utf16.as_ptr(), tunnel_type_utf16.as_ptr(), &guid_s) };

if result.is_null() {
Err("Failed to create adapter".into())
} else {
let luid = crate::ffi::alias_to_luid(&name_utf16)?;
let index = crate::ffi::luid_to_index(&luid)?;
Ok(Arc::new(Adapter {
adapter: UnsafeHandle(result),
wintun: wintun.clone(),
guid,
index,
luid,
}))
return Err("Failed to create adapter".into());
}
let luid = crate::ffi::alias_to_luid(&name_utf16)?;
let index = crate::ffi::luid_to_index(&luid)?;
Ok(Arc::new(Adapter {
adapter: UnsafeHandle(result),
wintun: wintun.clone(),
guid,
index,
luid,
}))
}

/// Attempts to open an existing wintun interface name `name`.
Expand All @@ -112,20 +111,19 @@ impl Adapter {
let result = unsafe { wintun.WintunOpenAdapter(name_utf16.as_ptr()) };

if result.is_null() {
Err("WintunOpenAdapter failed".into())
} else {
let luid = crate::ffi::alias_to_luid(&name_utf16)?;
let index = crate::ffi::luid_to_index(&luid)?;
let guid = crate::ffi::luid_to_guid(&luid)?;
let guid = util::win_guid_to_u128(&guid);
Ok(Arc::new(Adapter {
adapter: UnsafeHandle(result),
wintun: wintun.clone(),
guid,
index,
luid,
}))
return Err("WintunOpenAdapter failed".into());
}
let luid = crate::ffi::alias_to_luid(&name_utf16)?;
let index = crate::ffi::luid_to_index(&luid)?;
let guid = crate::ffi::luid_to_guid(&luid)?;
let guid = util::win_guid_to_u128(&guid);
Ok(Arc::new(Adapter {
adapter: UnsafeHandle(result),
wintun: wintun.clone(),
guid,
index,
luid,
}))
}

/// Delete an adapter, consuming it in the process
Expand Down Expand Up @@ -153,17 +151,15 @@ impl Adapter {
let result = unsafe { self.wintun.WintunStartSession(self.adapter.0, capacity) };

if result.is_null() {
Err("WintunStartSession failed".into())
} else {
let shutdown_event = unsafe { CreateEventA(std::ptr::null_mut(), FALSE, FALSE, std::ptr::null_mut()) };
Ok(session::Session {
session: UnsafeHandle(result),
wintun: self.wintun.clone(),
read_event: OnceLock::new(),
shutdown_event: UnsafeHandle(shutdown_event),
adapter: Arc::clone(self),
})
return Err("WintunStartSession failed".into());
}
let shutdown_event = unsafe { CreateEventA(std::ptr::null_mut(), FALSE, FALSE, std::ptr::null_mut()) };
Ok(session::Session {
session: UnsafeHandle(result),
read_event: OnceLock::new(),
shutdown_event: UnsafeHandle(shutdown_event),
adapter: self.clone(),
})
}

/// Returns the Win32 LUID for this adapter
Expand Down
2 changes: 1 addition & 1 deletion src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl Drop for Packet {
// ring buffer that the wintun session owns. We return that region of
// memory back to wintun here
self.session
.wintun
.get_wintun()
.WintunReleaseReceivePacket(self.session.session.0, self.bytes.as_ptr())
};
}
Expand Down
57 changes: 30 additions & 27 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ pub struct Session {
/// The session handle given to us by WintunStartSession
pub(crate) session: UnsafeHandle<wintun_raw::WINTUN_SESSION_HANDLE>,

/// Shared dll for required wintun driver functions
pub(crate) wintun: Wintun,

/// Windows event handle that is signaled by the wintun driver when data becomes available to
/// read
pub(crate) read_event: OnceLock<UnsafeHandle<HANDLE>>,
Expand All @@ -36,6 +33,10 @@ impl Session {
self.adapter.clone()
}

pub(crate) fn get_wintun(&self) -> Wintun {
self.adapter.wintun.clone()
}

/// Allocates a send packet of the specified size. Wraps WintunAllocateSendPacket
///
/// All packets returned from this function must be sent using [`Session::send_packet`] because
Expand All @@ -44,25 +45,26 @@ impl Session {
/// up the send queue for all other packets allocated in the future. It is okay for the session
/// to shutdown with allocated packets that have not yet been sent
pub fn allocate_send_packet(self: &Arc<Self>, size: u16) -> Result<packet::Packet, Error> {
let ptr = unsafe { self.wintun.WintunAllocateSendPacket(self.session.0, size as u32) };
let wintun = self.get_wintun();
let ptr = unsafe { wintun.WintunAllocateSendPacket(self.session.0, size as u32) };
if ptr.is_null() {
Err(util::get_last_error()?.into())
} else {
Ok(packet::Packet {
//SAFETY: ptr is non null, aligned for u8, and readable for up to size bytes (which
//must be less than isize::MAX because bytes is a u16
bytes: unsafe { slice::from_raw_parts_mut(ptr, size as usize) },
session: self.clone(),
kind: packet::Kind::SendPacketPending,
})
return Err(util::get_last_error()?.into());
}
Ok(packet::Packet {
//SAFETY: ptr is non null, aligned for u8, and readable for up to size bytes (which
//must be less than isize::MAX because bytes is a u16
bytes: unsafe { slice::from_raw_parts_mut(ptr, size as usize) },
session: self.clone(),
kind: packet::Kind::SendPacketPending,
})
}

/// Sends a packet previously allocated with [`Session::allocate_send_packet`]
pub fn send_packet(&self, mut packet: packet::Packet) {
assert!(matches!(packet.kind, packet::Kind::SendPacketPending));

unsafe { self.wintun.WintunSendPacket(self.session.0, packet.bytes.as_ptr()) };
let wintun = self.get_wintun();
unsafe { wintun.WintunSendPacket(self.session.0, packet.bytes.as_ptr()) };
//Mark the packet at sent
packet.kind = packet::Kind::SendPacketSent;
}
Expand All @@ -73,33 +75,34 @@ impl Session {
pub fn try_receive(self: &Arc<Self>) -> Result<Option<packet::Packet>, Error> {
let mut size = 0u32;

let ptr = unsafe { self.wintun.WintunReceivePacket(self.session.0, &mut size as *mut u32) };
let wintun = self.get_wintun();
let ptr = unsafe { wintun.WintunReceivePacket(self.session.0, &mut size as *mut u32) };

debug_assert!(size <= u16::MAX as u32);
if ptr.is_null() {
//Wintun returns ERROR_NO_MORE_ITEMS instead of blocking if packets are not available
match unsafe { GetLastError() } {
return match unsafe { GetLastError() } {
ERROR_NO_MORE_ITEMS => Ok(None),
e => Err(std::io::Error::from_raw_os_error(e as i32).into()),
}
} else {
Ok(Some(packet::Packet {
kind: packet::Kind::ReceivePacket,
//SAFETY: ptr is non null, aligned for u8, and readable for up to size bytes (which
//must be less than isize::MAX because bytes is a u16
bytes: unsafe { slice::from_raw_parts_mut(ptr, size as usize) },
session: self.clone(),
}))
};
}
Ok(Some(packet::Packet {
kind: packet::Kind::ReceivePacket,
//SAFETY: ptr is non null, aligned for u8, and readable for up to size bytes (which
//must be less than isize::MAX because bytes is a u16
bytes: unsafe { slice::from_raw_parts_mut(ptr, size as usize) },
session: self.clone(),
}))
}

/// # Safety
/// Returns the low level read event handle that is signaled when more data becomes available
/// to read
pub unsafe fn get_read_wait_event(&self) -> Result<HANDLE, Error> {
let wintun = self.get_wintun();
Ok(self
.read_event
.get_or_init(|| UnsafeHandle(self.wintun.WintunGetReadWaitEvent(self.session.0) as _))
.get_or_init(|| UnsafeHandle(wintun.WintunGetReadWaitEvent(self.session.0) as _))
.0)
}

Expand Down Expand Up @@ -162,7 +165,7 @@ impl Drop for Session {
log::error!("Failed to close handle of shutdown event: {:?}", err);
}

unsafe { self.wintun.WintunEndSession(self.session.0) };
unsafe { self.get_wintun().WintunEndSession(self.session.0) };
self.session.0 = ptr::null_mut();
}
}

0 comments on commit c0f17ab

Please sign in to comment.