Skip to content

Commit

Permalink
Rewrite the iterator function logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ssrlive committed Jan 18, 2025
1 parent 2606322 commit a0c80dc
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 51 deletions.
88 changes: 56 additions & 32 deletions src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,23 +312,27 @@ impl Adapter {
let mut adapter_addresses = vec![];

util::get_adapters_addresses(|adapter| {
let name_iter = unsafe { util::win_pstr_to_string(adapter.AdapterName) }?;
let name_iter = match unsafe { util::win_pstr_to_string(adapter.AdapterName) } {
Ok(name) => name,
Err(err) => {
log::error!("Failed to parse adapter name: {}", err);
return false;
}
};
if name_iter == name {
let mut current_address = adapter.FirstUnicastAddress;
while !current_address.is_null() {
let address = unsafe { (*current_address).Address };
let address = util::retrieve_ipaddr_from_socket_address(&address);
if let Err(err) = address {
log::error!("Failed to parse address: {}", err);
} else {
adapter_addresses.push(address?);
}
unsafe {
current_address = (*current_address).Next;
match util::retrieve_ipaddr_from_socket_address(&address) {
Ok(addr) => adapter_addresses.push(addr),
Err(err) => {
log::error!("Failed to parse address: {}", err);
}
}
unsafe { current_address = (*current_address).Next };
}
}
Ok(())
true
})?;

Ok(adapter_addresses)
Expand All @@ -339,23 +343,27 @@ impl Adapter {
let name = util::guid_to_win_style_string(&GUID::from_u128(self.guid))?;
let mut gateways = vec![];
util::get_adapters_addresses(|adapter| {
let name_iter = unsafe { util::win_pstr_to_string(adapter.AdapterName) }?;
let name_iter = match unsafe { util::win_pstr_to_string(adapter.AdapterName) } {
Ok(name) => name,
Err(err) => {
log::error!("Failed to parse adapter name: {}", err);
return false;
}
};
if name_iter == name {
let mut current_gateway = adapter.FirstGatewayAddress;
while !current_gateway.is_null() {
let gateway = unsafe { (*current_gateway).Address };
let gateway = util::retrieve_ipaddr_from_socket_address(&gateway);
if let Err(err) = gateway {
log::error!("Failed to parse gateway: {}", err);
} else {
gateways.push(gateway?);
}
unsafe {
current_gateway = (*current_gateway).Next;
match util::retrieve_ipaddr_from_socket_address(&gateway) {
Ok(addr) => gateways.push(addr),
Err(err) => {
log::error!("Failed to parse gateway: {}", err);
}
}
unsafe { current_gateway = (*current_gateway).Next };
}
}
Ok(())
true
})?;
Ok(gateways)
}
Expand All @@ -365,39 +373,55 @@ impl Adapter {
let name = util::guid_to_win_style_string(&GUID::from_u128(self.guid))?;
let mut subnet_mask = None;
util::get_adapters_addresses(|adapter| {
let name_iter = unsafe { util::win_pstr_to_string(adapter.AdapterName) }?;
let name_iter = match unsafe { util::win_pstr_to_string(adapter.AdapterName) } {
Ok(name) => name,
Err(err) => {
log::warn!("Failed to parse adapter name: {}", err);
return false;
}
};
if name_iter == name {
let mut current_address = adapter.FirstUnicastAddress;
while !current_address.is_null() {
let address = unsafe { (*current_address).Address };
let address = util::retrieve_ipaddr_from_socket_address(&address);
if let Err(ref err) = address {
log::warn!("Failed to parse address: {}", err);
}
let address = address?;
let address = match util::retrieve_ipaddr_from_socket_address(&address) {
Ok(addr) => addr,
Err(err) => {
log::warn!("Failed to parse address: {}", err);
return false;
}
};
if address == *target_address {
let masklength = unsafe { (*current_address).OnLinkPrefixLength };
match address {
IpAddr::V4(_) => {
let mut mask = 0_u32;
match unsafe { ConvertLengthToIpv4Mask(masklength as u32, &mut mask as *mut u32) } {
0 => {}
err => return Err(std::io::Error::from_raw_os_error(err as i32).into()),
err => {
log::warn!("Failed to convert length to mask: {}", err);
return false;
}
}
subnet_mask = Some(IpAddr::V4(Ipv4Addr::from(mask.to_le_bytes())));
}
IpAddr::V6(_) => {
subnet_mask = Some(IpAddr::V6(util::ipv6_netmask_for_prefix(masklength)?));
let v = match util::ipv6_netmask_for_prefix(masklength) {
Ok(v) => v,
Err(err) => {
log::warn!("Failed to convert length to mask: {}", err);
return false;
}
};
subnet_mask = Some(IpAddr::V6(v));
}
}
break;
}
unsafe {
current_address = (*current_address).Next;
}
unsafe { current_address = (*current_address).Next };
}
}
Ok(())
true
})?;

Ok(subnet_mask.ok_or("Unable to find matching address")?)
Expand Down
55 changes: 36 additions & 19 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,19 @@ pub fn get_active_network_interface_gateways() -> std::io::Result<Vec<IpAddr>> {
{
let sockaddr_ptr = gateway.Address.lpSockaddr;
let sockaddr = unsafe { &*(sockaddr_ptr as *const SOCKADDR) };
let a = unsafe { sockaddr_to_socket_addr(sockaddr) }?;
let a = match unsafe { sockaddr_to_socket_addr(sockaddr) } {
Ok(a) => a,
Err(e) => {
log::error!("Failed to convert sockaddr to socket address: {}", e);
return false;
}
};
addrs.push(a.ip());
}
current_gateway = gateway.Next;
}
}
Ok(())
true
})?;
Ok(addrs)
}
Expand Down Expand Up @@ -237,7 +243,7 @@ pub(crate) unsafe fn sockaddr_in6_to_socket_addr(sockaddr_in6: &SOCKADDR_IN6) ->

pub(crate) fn get_adapters_addresses<F>(mut callback: F) -> Result<(), Error>
where
F: FnMut(IP_ADAPTER_ADDRESSES_LH) -> Result<(), Error>,
F: FnMut(IP_ADAPTER_ADDRESSES_LH) -> bool,
{
let mut size = 0;
let flags = GAA_FLAG_INCLUDE_PREFIX | GAA_FLAG_INCLUDE_GATEWAYS;
Expand Down Expand Up @@ -276,7 +282,9 @@ where
let mut current_addresses = addresses.as_ptr() as *const IP_ADAPTER_ADDRESSES_LH;
while !current_addresses.is_null() {
unsafe {
callback(*current_addresses)?;
if !callback(*current_addresses) {
break;
}
current_addresses = (*current_addresses).Next;
}
}
Expand All @@ -285,7 +293,7 @@ where

fn get_interface_info_sys<F>(mut callback: F) -> Result<(), Error>
where
F: FnMut(IP_ADAPTER_INDEX_MAP) -> Result<(), Error>,
F: FnMut(IP_ADAPTER_INDEX_MAP) -> bool,
{
let mut buf_len: u32 = 0;
//First figure out the size of the buffer needed to store the adapter info
Expand Down Expand Up @@ -357,7 +365,9 @@ where
let interfaces = unsafe { std::slice::from_raw_parts(first_adapter, adapter_count as usize) };

for interface in interfaces {
callback(*interface)?;
if !callback(*interface) {
break;
}
}
Ok(())
}
Expand All @@ -366,16 +376,24 @@ where
pub(crate) fn get_interface_info() -> Result<Vec<(u32, String)>, Error> {
let mut v = vec![];
get_interface_info_sys(|mut interface| {
let name = unsafe { win_pwstr_to_string(&mut interface.Name as _)? };
let name = match unsafe { win_pwstr_to_string(&mut interface.Name as _) } {
Ok(name) => name,
Err(e) => {
log::error!("Failed to convert interface name: {}", e);
return false;
}
};
// Nam is something like: \DEVICE\TCPIP_{29C47F55-C7BD-433A-8BF7-408DFD3B3390}
// where the GUID is the {29C4...90}, separated by dashes
let guid = name
.split('{')
.nth(1)
.and_then(|s| s.split('}').next())
.ok_or(format!("Failed to find GUID inside adapter name: {}", name))?;
v.push((interface.Index, guid.to_string()));
Ok(())
let guid = match name.split('{').nth(1).and_then(|s| s.split('}').next()) {
Some(guid) => guid.to_string(),
None => {
log::error!("Failed to extract GUID from interface name: {}", name);
return false;
}
};
v.push((interface.Index, guid));
true
})?;
Ok(v)
}
Expand Down Expand Up @@ -538,7 +556,7 @@ pub(crate) fn get_mtu_by_index(index: u32, is_ipv6: bool) -> std::io::Result<u32
if item.InterfaceIndex == index {
mtu = Some(item.NlMtu);
}
Ok(())
true
},
is_ipv6,
)?;
Expand All @@ -555,7 +573,7 @@ pub fn decode_utf16(string: &[u16]) -> String {

pub fn get_ip_interface_table<F>(mut callback: F, is_ipv6: bool) -> std::io::Result<()>
where
F: FnMut(&MIB_IPINTERFACE_ROW) -> std::io::Result<()>,
F: FnMut(&MIB_IPINTERFACE_ROW) -> bool,
{
let mut if_table: *mut MIB_IPINTERFACE_TABLE = std::ptr::null_mut();
unsafe {
Expand All @@ -568,9 +586,8 @@ where
use std::slice::from_raw_parts;
let ifaces = from_raw_parts::<MIB_IPINTERFACE_ROW>(&(*if_table).Table[0], (*if_table).NumEntries as usize);
for item in ifaces {
if let Err(e) = callback(item) {
FreeMibTable(if_table as _);
return Err(e);
if !callback(item) {
break;
}
}
FreeMibTable(if_table as _);
Expand Down

0 comments on commit a0c80dc

Please sign in to comment.