diff --git a/src/adapter.rs b/src/adapter.rs index ea3e472..16a5e49 100644 --- a/src/adapter.rs +++ b/src/adapter.rs @@ -30,7 +30,7 @@ use windows_sys::{ /// Wrapper around a pub struct Adapter { adapter: UnsafeHandle, - wintun: Wintun, + pub(crate) wintun: Wintun, guid: u128, index: u32, luid: NET_LUID_LH, @@ -89,18 +89,17 @@ impl Adapter { let result = unsafe { wintun.WintunCreateAdapter(name_utf16.as_ptr(), tunnel_type_utf16.as_ptr(), &guid_s) }; if result.is_null() { - Err("Failed to create adapter".into()) - } else { - let luid = crate::ffi::alias_to_luid(&name_utf16)?; - let index = crate::ffi::luid_to_index(&luid)?; - Ok(Arc::new(Adapter { - adapter: UnsafeHandle(result), - wintun: wintun.clone(), - guid, - index, - luid, - })) + return Err("Failed to create adapter".into()); } + let luid = crate::ffi::alias_to_luid(&name_utf16)?; + let index = crate::ffi::luid_to_index(&luid)?; + Ok(Arc::new(Adapter { + adapter: UnsafeHandle(result), + wintun: wintun.clone(), + guid, + index, + luid, + })) } /// Attempts to open an existing wintun interface name `name`. @@ -112,20 +111,19 @@ impl Adapter { let result = unsafe { wintun.WintunOpenAdapter(name_utf16.as_ptr()) }; if result.is_null() { - Err("WintunOpenAdapter failed".into()) - } else { - let luid = crate::ffi::alias_to_luid(&name_utf16)?; - let index = crate::ffi::luid_to_index(&luid)?; - let guid = crate::ffi::luid_to_guid(&luid)?; - let guid = util::win_guid_to_u128(&guid); - Ok(Arc::new(Adapter { - adapter: UnsafeHandle(result), - wintun: wintun.clone(), - guid, - index, - luid, - })) + return Err("WintunOpenAdapter failed".into()); } + let luid = crate::ffi::alias_to_luid(&name_utf16)?; + let index = crate::ffi::luid_to_index(&luid)?; + let guid = crate::ffi::luid_to_guid(&luid)?; + let guid = util::win_guid_to_u128(&guid); + Ok(Arc::new(Adapter { + adapter: UnsafeHandle(result), + wintun: wintun.clone(), + guid, + index, + luid, + })) } /// Delete an adapter, consuming it in the process @@ -153,17 +151,15 @@ impl Adapter { let result = unsafe { self.wintun.WintunStartSession(self.adapter.0, capacity) }; if result.is_null() { - Err("WintunStartSession failed".into()) - } else { - let shutdown_event = unsafe { CreateEventA(std::ptr::null_mut(), FALSE, FALSE, std::ptr::null_mut()) }; - Ok(session::Session { - session: UnsafeHandle(result), - wintun: self.wintun.clone(), - read_event: OnceLock::new(), - shutdown_event: UnsafeHandle(shutdown_event), - adapter: Arc::clone(self), - }) + return Err("WintunStartSession failed".into()); } + let shutdown_event = unsafe { CreateEventA(std::ptr::null_mut(), FALSE, FALSE, std::ptr::null_mut()) }; + Ok(session::Session { + session: UnsafeHandle(result), + read_event: OnceLock::new(), + shutdown_event: UnsafeHandle(shutdown_event), + adapter: self.clone(), + }) } /// Returns the Win32 LUID for this adapter diff --git a/src/packet.rs b/src/packet.rs index dada85f..2b8641f 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -67,7 +67,7 @@ impl Drop for Packet { // ring buffer that the wintun session owns. We return that region of // memory back to wintun here self.session - .wintun + .get_wintun() .WintunReleaseReceivePacket(self.session.session.0, self.bytes.as_ptr()) }; } diff --git a/src/session.rs b/src/session.rs index 07becbf..694c7fb 100644 --- a/src/session.rs +++ b/src/session.rs @@ -16,9 +16,6 @@ pub struct Session { /// The session handle given to us by WintunStartSession pub(crate) session: UnsafeHandle, - /// Shared dll for required wintun driver functions - pub(crate) wintun: Wintun, - /// Windows event handle that is signaled by the wintun driver when data becomes available to /// read pub(crate) read_event: OnceLock>, @@ -36,6 +33,10 @@ impl Session { self.adapter.clone() } + pub(crate) fn get_wintun(&self) -> Wintun { + self.adapter.wintun.clone() + } + /// Allocates a send packet of the specified size. Wraps WintunAllocateSendPacket /// /// All packets returned from this function must be sent using [`Session::send_packet`] because @@ -44,25 +45,26 @@ impl Session { /// up the send queue for all other packets allocated in the future. It is okay for the session /// to shutdown with allocated packets that have not yet been sent pub fn allocate_send_packet(self: &Arc, size: u16) -> Result { - let ptr = unsafe { self.wintun.WintunAllocateSendPacket(self.session.0, size as u32) }; + let wintun = self.get_wintun(); + let ptr = unsafe { wintun.WintunAllocateSendPacket(self.session.0, size as u32) }; if ptr.is_null() { - Err(util::get_last_error()?.into()) - } else { - Ok(packet::Packet { - //SAFETY: ptr is non null, aligned for u8, and readable for up to size bytes (which - //must be less than isize::MAX because bytes is a u16 - bytes: unsafe { slice::from_raw_parts_mut(ptr, size as usize) }, - session: self.clone(), - kind: packet::Kind::SendPacketPending, - }) + return Err(util::get_last_error()?.into()); } + Ok(packet::Packet { + //SAFETY: ptr is non null, aligned for u8, and readable for up to size bytes (which + //must be less than isize::MAX because bytes is a u16 + bytes: unsafe { slice::from_raw_parts_mut(ptr, size as usize) }, + session: self.clone(), + kind: packet::Kind::SendPacketPending, + }) } /// Sends a packet previously allocated with [`Session::allocate_send_packet`] pub fn send_packet(&self, mut packet: packet::Packet) { assert!(matches!(packet.kind, packet::Kind::SendPacketPending)); - unsafe { self.wintun.WintunSendPacket(self.session.0, packet.bytes.as_ptr()) }; + let wintun = self.get_wintun(); + unsafe { wintun.WintunSendPacket(self.session.0, packet.bytes.as_ptr()) }; //Mark the packet at sent packet.kind = packet::Kind::SendPacketSent; } @@ -73,33 +75,34 @@ impl Session { pub fn try_receive(self: &Arc) -> Result, Error> { let mut size = 0u32; - let ptr = unsafe { self.wintun.WintunReceivePacket(self.session.0, &mut size as *mut u32) }; + let wintun = self.get_wintun(); + let ptr = unsafe { wintun.WintunReceivePacket(self.session.0, &mut size as *mut u32) }; debug_assert!(size <= u16::MAX as u32); if ptr.is_null() { //Wintun returns ERROR_NO_MORE_ITEMS instead of blocking if packets are not available - match unsafe { GetLastError() } { + return match unsafe { GetLastError() } { ERROR_NO_MORE_ITEMS => Ok(None), e => Err(std::io::Error::from_raw_os_error(e as i32).into()), - } - } else { - Ok(Some(packet::Packet { - kind: packet::Kind::ReceivePacket, - //SAFETY: ptr is non null, aligned for u8, and readable for up to size bytes (which - //must be less than isize::MAX because bytes is a u16 - bytes: unsafe { slice::from_raw_parts_mut(ptr, size as usize) }, - session: self.clone(), - })) + }; } + Ok(Some(packet::Packet { + kind: packet::Kind::ReceivePacket, + //SAFETY: ptr is non null, aligned for u8, and readable for up to size bytes (which + //must be less than isize::MAX because bytes is a u16 + bytes: unsafe { slice::from_raw_parts_mut(ptr, size as usize) }, + session: self.clone(), + })) } /// # Safety /// Returns the low level read event handle that is signaled when more data becomes available /// to read pub unsafe fn get_read_wait_event(&self) -> Result { + let wintun = self.get_wintun(); Ok(self .read_event - .get_or_init(|| UnsafeHandle(self.wintun.WintunGetReadWaitEvent(self.session.0) as _)) + .get_or_init(|| UnsafeHandle(wintun.WintunGetReadWaitEvent(self.session.0) as _)) .0) } @@ -162,7 +165,7 @@ impl Drop for Session { log::error!("Failed to close handle of shutdown event: {:?}", err); } - unsafe { self.wintun.WintunEndSession(self.session.0) }; + unsafe { self.get_wintun().WintunEndSession(self.session.0) }; self.session.0 = ptr::null_mut(); } }