From 5531c2eb9376db098ba503ee78c09819979ca997 Mon Sep 17 00:00:00 2001 From: Ishan Jain Date: Wed, 22 Jan 2025 08:07:12 +0530 Subject: [PATCH] fix: fixed a bug in the binary tree traversal, refactored other parts of the code --- geofw-common/src/lib.rs | 10 +-- geofw-ebpf/src/main.rs | 55 ++++++++-------- geofw/src/main.rs | 54 ++++++++-------- geofw/src/maxmind.rs | 139 ++++++++++++++++++++++------------------ 4 files changed, 136 insertions(+), 122 deletions(-) diff --git a/geofw-common/src/lib.rs b/geofw-common/src/lib.rs index 7850d10..1dd51ea 100644 --- a/geofw-common/src/lib.rs +++ b/geofw-common/src/lib.rs @@ -17,20 +17,20 @@ pub enum ProgramParameters { pub const BLOCK_MARKER: u32 = 0x00ffffff; #[derive(Copy, Clone)] -pub enum MaxmindDb { +pub enum MaxmindDbType { Country, Asn, } -impl Display for MaxmindDb { +impl Display for MaxmindDbType { fn fmt(&self, f: &mut Formatter) -> FmtResult { let val = match self { - MaxmindDb::Country => "GeoLite2-Country", - MaxmindDb::Asn => "GeoLite2-ASN", + MaxmindDbType::Country => "GeoLite2-Country", + MaxmindDbType::Asn => "GeoLite2-ASN", }; write!(f, "{val}") } } -// impl aya_log_ebpf::macro_support::Formatter for MaxmindDb {} + diff --git a/geofw-ebpf/src/main.rs b/geofw-ebpf/src/main.rs index fef98d0..919fc36 100644 --- a/geofw-ebpf/src/main.rs +++ b/geofw-ebpf/src/main.rs @@ -9,7 +9,7 @@ use aya_ebpf::{ }; use aya_log_ebpf::{debug, warn}; use core::{mem, net::IpAddr}; -use geofw_common::{MaxmindDb, ProgramParameters, BLOCK_MARKER}; +use geofw_common::{MaxmindDbType, ProgramParameters, BLOCK_MARKER}; use network_types::{ eth::{EthHdr, EtherType}, ip::{Ipv4Hdr, Ipv6Hdr}, @@ -60,10 +60,10 @@ fn filter_ip_packet(ctx: XdpContext) -> Result { let ip: *const Ipv4Hdr = ptr_at(&ctx, EthHdr::LEN).ok_or(xdp_action::XDP_PASS)?; let source = unsafe { (*ip).src_addr() }; - let result = should_block(&ctx, MaxmindDb::Asn, &BLOCKED_ASN, IpAddr::V4(source)) + let result = should_block(&ctx, MaxmindDbType::Asn, &BLOCKED_ASN, IpAddr::V4(source)) || should_block( &ctx, - MaxmindDb::Country, + MaxmindDbType::Country, &BLOCKED_COUNTRY, IpAddr::V4(source), ); @@ -82,10 +82,10 @@ fn filter_ipv6_packet(ctx: XdpContext) -> Result { let ip: *const Ipv6Hdr = ptr_at(&ctx, EthHdr::LEN).ok_or(xdp_action::XDP_PASS)?; let source = unsafe { (*ip).src_addr() }; - let result = should_block(&ctx, MaxmindDb::Asn, &BLOCKED_ASN, IpAddr::V6(source)) + let result = should_block(&ctx, MaxmindDbType::Asn, &BLOCKED_ASN, IpAddr::V6(source)) || should_block( &ctx, - MaxmindDb::Country, + MaxmindDbType::Country, &BLOCKED_COUNTRY, IpAddr::V6(source), ); @@ -101,22 +101,27 @@ fn filter_ipv6_packet(ctx: XdpContext) -> Result { } } -pub fn should_block(ctx: &XdpContext, db_name: MaxmindDb, map: &Array, addr: IpAddr) -> bool { - let record_size = match db_name { - MaxmindDb::Country => unsafe { +pub fn should_block( + ctx: &XdpContext, + db_type: MaxmindDbType, + map: &Array, + addr: IpAddr, +) -> bool { + let record_size = match db_type { + MaxmindDbType::Country => unsafe { PARAMETERS.get(&(ProgramParameters::CountryRecordSize as u8)) }, - MaxmindDb::Asn => unsafe { PARAMETERS.get(&(ProgramParameters::AsnRecordSize as u8)) }, + MaxmindDbType::Asn => unsafe { PARAMETERS.get(&(ProgramParameters::AsnRecordSize as u8)) }, }; let Some(&record_size) = record_size else { return false; }; - let node_count = match db_name { - MaxmindDb::Country => unsafe { + let node_count = match db_type { + MaxmindDbType::Country => unsafe { PARAMETERS.get(&(ProgramParameters::CountryNodeCount as u8)) }, - MaxmindDb::Asn => unsafe { PARAMETERS.get(&(ProgramParameters::AsnNodeCount as u8)) }, + MaxmindDbType::Asn => unsafe { PARAMETERS.get(&(ProgramParameters::AsnNodeCount as u8)) }, }; let Some(&node_count) = node_count else { return false; @@ -135,7 +140,7 @@ pub fn should_block(ctx: &XdpContext, db_name: MaxmindDb, map: &Array, addr: }; while i >= 0 && node < node_count { - let bit = (ip & (1 << 127)) == 0; + let left = (ip & (1 << 127)) == 0; ip <<= 1; let mut slice = [0; 8]; @@ -152,31 +157,21 @@ pub fn should_block(ctx: &XdpContext, db_name: MaxmindDb, map: &Array, addr: } } } - node = node_from_bytes(slice, bit, record_size as u16); + node = node_from_bytes(slice, left, record_size as u16); i -= 1; } node == BLOCK_MARKER } -fn node_from_bytes(n: [u8; 8], bit: bool, record_size: u16) -> u32 { +fn node_from_bytes(n: [u8; 8], left: bool, record_size: u16) -> u32 { match record_size { - 28 => { - if bit { - u32::from_be_bytes([(n[3] & 0b1111_0000) >> 4, n[0], n[1], n[2]]) - } else { - u32::from_be_bytes([n[3] & 0b0000_1111, n[4], n[5], n[6]]) - } - } - 24 => { - if bit { - u32::from_be_bytes([0, n[0], n[1], n[2]]) - } else { - u32::from_be_bytes([0, n[3], n[4], n[5]]) - } - } + 28 if left => u32::from_be_bytes([(n[3] & 0b1111_0000) >> 4, n[0], n[1], n[2]]), + 28 => u32::from_be_bytes([n[3] & 0b0000_1111, n[4], n[5], n[6]]), + 24 if left => u32::from_be_bytes([0, n[0], n[1], n[2]]), + 24 => u32::from_be_bytes([0, n[3], n[4], n[5]]), - // this should never reach + // This should never run unless we are using 32bit dbs _ => 0, } } diff --git a/geofw/src/main.rs b/geofw/src/main.rs index 702a566..41df923 100644 --- a/geofw/src/main.rs +++ b/geofw/src/main.rs @@ -8,7 +8,7 @@ use aya::{ }; use flate2::bufread::GzDecoder; use fxhash::FxHashSet; -use geofw_common::{MaxmindDb, ProgramParameters}; +use geofw_common::{MaxmindDbType, ProgramParameters}; use log::{debug, info, warn}; use maxmind::{Data, ProcessedDb}; use serde_derive::{Deserialize, Serialize}; @@ -16,6 +16,7 @@ use std::{ fs::File, io::{BufReader, ErrorKind, Read, Write}, path::PathBuf, + time::Instant, }; use tar::Archive; use tokio::{signal, time}; @@ -81,12 +82,12 @@ fn read_config(path: &str) -> Result { } } -fn fetch_geoip_db(config: &Config, db_name: MaxmindDb) -> Result { +fn fetch_geoip_db(config: &Config, db_type: MaxmindDbType) -> Result { let mut unpack_path = PathBuf::new(); unpack_path.push(&config.db.path); - unpack_path.push(format!("{}.mmdb", db_name)); + unpack_path.push(format!("{}.mmdb", db_type)); - let url = format!("https://download.maxmind.com/app/geoip_download?edition_id={}&license_key={}&suffix=tar.gz", db_name, config.db.maxmind_key); + let url = format!("https://download.maxmind.com/app/geoip_download?edition_id={}&license_key={}&suffix=tar.gz", db_type, config.db.maxmind_key); info!("path = {:?} fetching db from = {}", unpack_path, url); @@ -130,10 +131,10 @@ fn fetch_geoip_db(config: &Config, db_name: MaxmindDb) -> Result Ok(db.consume(|data| -> bool { + match db_type { + MaxmindDbType::Country => Ok(db.consume(|data| -> bool { let Some(Data::Map(country)) = data.get("country".as_bytes()) else { return false; }; @@ -143,7 +144,7 @@ fn fetch_geoip_db(config: &Config, db_name: MaxmindDb) -> Result Ok(db.consume(|data| -> bool { + MaxmindDbType::Asn => Ok(db.consume(|data| -> bool { let Some(Data::U32(asn)) = data.get("autonomous_system_number".as_bytes()) else { return false; }; @@ -194,17 +195,17 @@ async fn main() -> anyhow::Result<()> { _ = interval.tick() => { info!("updating DB"); - match update_geoip_map(&config, &mut ebpf, MaxmindDb::Country, "BLOCKED_COUNTRY") { + match update_geoip_map(&config, &mut ebpf, MaxmindDbType::Country, "BLOCKED_COUNTRY") { Ok(_) => (), Err(e) => { - warn!("error in updating map {} = {}", MaxmindDb::Country, e); + warn!("error in updating map {} = {}", MaxmindDbType::Country, e); } } - match update_geoip_map(&config, &mut ebpf, MaxmindDb::Asn, "BLOCKED_ASN") { + match update_geoip_map(&config, &mut ebpf, MaxmindDbType::Asn, "BLOCKED_ASN") { Ok(_) => (), Err(e) => { - warn!("error in updating map {} = {}", MaxmindDb::Asn, e); + warn!("error in updating map {} = {}", MaxmindDbType::Asn, e); } } } @@ -217,35 +218,38 @@ async fn main() -> anyhow::Result<()> { fn update_geoip_map( config: &Config, ebpf: &mut Ebpf, - db_name: MaxmindDb, + db_type: MaxmindDbType, map_name: &str, ) -> Result<(), String> { - info!("updating maps db_name = {db_name} map_name = {map_name}"); + info!("updating maps db_type = {db_type} map_name = {map_name}"); let mut map = Array::try_from(ebpf.map_mut(map_name).expect("error in getting map")) .expect("error in processing map"); - let result = fetch_geoip_db(config, db_name)?; - - info!( - "set map = {map_name} up to the location = {} record_size = {} node_count = {}", - result.db.len(), - result.record_size, - result.node_count - ); + let result = fetch_geoip_db(config, db_type)?; + let t = Instant::now(); for (i, v) in result.db.into_iter().enumerate() { map.set(i as u32, v, 0).map_err(|e| e.to_string())?; } + info!( + "updated map = {} record_size = {} node_count = {} est_size = {} time_taken = {:?}", + map_name, + result.record_size, + result.node_count, + result.record_size as u64 * result.node_count as u64, + t.elapsed() + ); + let mut map: HashMap<&mut MapData, u8, u32> = HashMap::try_from( ebpf.map_mut("PARAMETERS") .expect("error in getting parameter map"), ) .expect("error in processing parameter map"); - match db_name { - MaxmindDb::Country => { + match db_type { + MaxmindDbType::Country => { map.insert( ProgramParameters::CountryNodeCount as u8, result.node_count, @@ -259,7 +263,7 @@ fn update_geoip_map( ) .expect("error in writing country record size to map"); } - MaxmindDb::Asn => { + MaxmindDbType::Asn => { map.insert(ProgramParameters::AsnNodeCount as u8, result.node_count, 0) .expect("error in writing country node count to map"); map.insert( diff --git a/geofw/src/maxmind.rs b/geofw/src/maxmind.rs index caa5271..988f926 100644 --- a/geofw/src/maxmind.rs +++ b/geofw/src/maxmind.rs @@ -2,6 +2,7 @@ use core::str; use fxhash::FxHashMap; use geofw_common::BLOCK_MARKER; use std::{ + collections::VecDeque, fmt::{Debug, Display, Formatter, Result as FmtResult}, fs::File, io::Read, @@ -12,7 +13,7 @@ const METADATA_SECTION_START: &[u8] = &[ 0xab, 0xcd, 0xef, 0x4d, 0x61, 0x78, 0x4d, 0x69, 0x6e, 0x64, 0x2e, 0x63, 0x6f, 0x6d, ]; -pub struct MaxmindDB { +pub struct MaxmindDb { pub metadata: Metadata, pub data: Vec, } @@ -81,13 +82,13 @@ impl Display for Data<'_> { } } -impl Debug for MaxmindDB { +impl Debug for MaxmindDb { fn fmt(&self, f: &mut Formatter) -> FmtResult { f.write_fmt(format_args!("{:?}", self.metadata)) } } -impl MaxmindDB { +impl MaxmindDb { pub fn from_file(path: &str) -> Result { let mut data = vec![]; let mut file = File::open(path).map_err(|e| format!("error in opening file: {}", e))?; @@ -131,40 +132,30 @@ impl MaxmindDB { map } - fn node_from_bytes(n: &[u8], bit: bool, record_size: u16) -> u32 { + fn node_from_bytes(n: &[u8], left: bool, record_size: u16) -> u32 { match record_size { - 28 => { - if bit { - u32::from_be_bytes([(n[3] & 0b1111_0000) >> 4, n[0], n[1], n[2]]) - } else { - u32::from_be_bytes([n[3] & 0b0000_1111, n[4], n[5], n[6]]) - } - } - 24 => { - if bit { - u32::from_be_bytes([0, n[0], n[1], n[2]]) - } else { - u32::from_be_bytes([0, n[3], n[4], n[5]]) - } - } + 28 if left => u32::from_be_bytes([(n[3] & 0b1111_0000) >> 4, n[0], n[1], n[2]]), + 28 => u32::from_be_bytes([n[3] & 0b0000_1111, n[4], n[5], n[6]]), + 24 if left => u32::from_be_bytes([0, n[0], n[1], n[2]]), + 24 => u32::from_be_bytes([0, n[3], n[4], n[5]]), _ => unreachable!(), } } - fn write_over_node_bytes(n: &mut [u8], bit: u128, record_size: u16, val: u32) { + fn write_over_node_bytes(n: &mut [u8], left: bool, record_size: u16, val: u32) { let val = val.to_be_bytes(); match record_size { - 28 if bit == 0 => { + 28 if left => { n[0..=2].copy_from_slice(&val[1..=3]); n[3] = (n[3] & 0b0000_1111) | (val[0] << 4); } - 28 if bit == 1 => { + 28 => { n[4..=6].copy_from_slice(&val[1..=3]); n[3] = (n[3] & 0b1111_0000) | (val[0] & 0b0000_1111); } - 24 if bit == 0 => n[0..=2].copy_from_slice(&val[1..=3]), - 24 if bit == 1 => n[3..=5].copy_from_slice(&val[1..=3]), + 24 if left => n[0..=2].copy_from_slice(&val[1..=3]), + 24 => n[3..=5].copy_from_slice(&val[1..=3]), _ => unreachable!(), } } @@ -189,10 +180,10 @@ impl MaxmindDB { }; while i >= 0 && node < self.metadata.node_count { - let bit = (ip & (1 << i)) == 0; + let left = (ip & (1 << i)) == 0; let n = &self.data[node as usize * node_size..(node as usize * node_size) + node_size]; - node = Self::node_from_bytes(n, bit, self.metadata.record_size); + node = Self::node_from_bytes(n, left, self.metadata.record_size); i -= 1; } @@ -207,50 +198,50 @@ impl MaxmindDB { } } - pub fn consume(mut self, should_block: impl Fn(FxHashMap<&[u8], Data>) -> bool) -> ProcessedDb { - let mut stack = vec![]; + pub fn consume( + mut self, + should_block: impl Fn(&FxHashMap<&[u8], Data>) -> bool, + ) -> ProcessedDb { + let mut stack = VecDeque::new(); let node_size = self.metadata.record_size as usize * 2 / 8; - stack.push((0, 0)); + stack.push_back((0, 0, false)); + + while let Some((node, parent, bit)) = stack.pop_front() { + if node == BLOCK_MARKER { + continue; + } + if node >= self.metadata.node_count { + let ds_offset = node - self.metadata.node_count; + + let (data, _) = + self.read_data(self.metadata.data_section_start + ds_offset as usize - 16); + + let Data::Map(data) = data else { + unreachable!() + }; + if should_block(&data) { + // Mark the parent of this node as non existent + let node = parent; + + Self::write_over_node_bytes( + &mut self.data + [node as usize * node_size..(node as usize * node_size) + node_size], + bit, + self.metadata.record_size, + BLOCK_MARKER, + ); + } + + continue; + } - while let Some((node, position)) = stack.pop() { let n = &mut self.data[node as usize * node_size..(node as usize * node_size) + node_size]; let node_1 = Self::node_from_bytes(n, false, self.metadata.record_size); let node_2 = Self::node_from_bytes(n, true, self.metadata.record_size); - if position < 128 && node_1 < self.metadata.node_count { - stack.push((node_1, position + 1)); - } - if position < 128 && node_2 < self.metadata.node_count { - stack.push((node_2, position + 1)); - } - - let data_section_offset = if node_1 != BLOCK_MARKER && node_1 > self.metadata.node_count - { - node_1 - self.metadata.node_count - } else if node_2 != BLOCK_MARKER && node_2 > self.metadata.node_count { - node_2 - self.metadata.node_count - } else { - continue; - }; - - let (data, _) = self - .read_data(self.metadata.data_section_start + data_section_offset as usize - 16); - - let Data::Map(data) = data else { - unreachable!() - }; - - if should_block(data) { - // Mark this node as non existent - Self::write_over_node_bytes( - &mut self.data - [node as usize * node_size..(node as usize * node_size) + node_size], - 0, - self.metadata.record_size, - BLOCK_MARKER, - ); - } + stack.push_back((node_1, node, false)); + stack.push_back((node_2, node, true)); } // Trim database to only contain the binary tree @@ -443,3 +434,27 @@ impl MaxmindDB { (data_type, length, read + r) } } + +impl ProcessedDb { + pub fn lookup(&self, addr: IpAddr) -> bool { + let node_size = self.record_size as usize * 2 / 8; + let mut node = 0; + + let mut ip = match addr { + IpAddr::V4(a) => a.to_bits() as u128, + IpAddr::V6(a) => a.to_bits(), + }; + + let mut i = 0; + while i < 128 && node < self.node_count { + let left = (ip & (1 << 127)) == 0; + ip <<= 1; + + let n = &self.db[node as usize * node_size..(node as usize * node_size) + node_size]; + node = MaxmindDb::node_from_bytes(n, left, self.record_size); + i += 1; + } + + node == BLOCK_MARKER + } +}