fix: use closures in consume

This commit is contained in:
Ishan Jain 2025-01-21 16:09:41 +05:30
parent 29908b56f4
commit 909966e475
3 changed files with 74 additions and 97 deletions

View File

@ -3,7 +3,7 @@
"maxmind_key": "", "maxmind_key": "",
"//": "refresh every 24h", "//": "refresh every 24h",
"refresh_interval": 86400, "refresh_interval": 86400,
"path": "/tmp/geofw" "path": "/home/ishan/geofw/geofw"
}, },
"interface": "enp6s18", "interface": "enp6s18",
"source_countries": [ "source_countries": [

View File

@ -9,18 +9,18 @@ use aya::{
use flate2::bufread::GzDecoder; use flate2::bufread::GzDecoder;
use fxhash::FxHashSet; use fxhash::FxHashSet;
use geofw_common::ProgramParameters; use geofw_common::ProgramParameters;
use log::{debug, info, warn}; use log::{debug, error, info, warn};
use maxmind::ProcessedDb; use maxmind::{Data, ProcessedDb};
use serde_derive::{Deserialize, Serialize}; use serde_derive::{Deserialize, Serialize};
use std::{ use std::{
fs::File, fs::File,
io::{BufReader, Read}, io::{BufReader, ErrorKind, Read},
path::PathBuf, path::PathBuf,
}; };
use tar::Archive; use tar::Archive;
use tokio::{signal, time}; use tokio::{signal, time};
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Config { pub struct Config {
pub db: Db, pub db: Db,
pub interface: String, pub interface: String,
@ -28,22 +28,52 @@ pub struct Config {
pub source_asn: FxHashSet<u32>, pub source_asn: FxHashSet<u32>,
} }
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] impl Default for Config {
fn default() -> Self {
Self {
db: Default::default(),
interface: "enp1s0".to_string(),
source_countries: Default::default(),
source_asn: Default::default(),
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Db { pub struct Db {
pub maxmind_key: String, pub maxmind_key: String,
pub refresh_interval: i64, pub refresh_interval: i64,
pub path: String, pub path: String,
} }
impl Default for Db {
fn default() -> Self {
Self {
maxmind_key: "".to_string(),
refresh_interval: 86400,
path: "/tmp/geofw".to_string(),
}
}
}
const COUNTRY_DB: &str = "GeoLite2-Country"; const COUNTRY_DB: &str = "GeoLite2-Country";
const ASN_DB: &str = "GeoLite2-ASN"; const ASN_DB: &str = "GeoLite2-ASN";
fn read_config(path: &str) -> Result<Config, String> { fn read_config(path: &str) -> Result<Config, String> {
let mut f = File::open(path).map_err(|e| e.to_string())?; match File::open(path) {
Ok(mut f) => {
let mut contents = vec![]; let mut contents = vec![];
f.read_to_end(&mut contents).map_err(|e| e.to_string())?; f.read_to_end(&mut contents).map_err(|e| e.to_string())?;
serde_json::from_slice(&contents).map_err(|e| e.to_string()) serde_json::from_slice(&contents).map_err(|e| e.to_string())
}
Err(e) if e.kind() == ErrorKind::NotFound => {
if let Err(e) = File::create(path) {
warn!("error in writing config to {}: {}", path, e);
}
Ok(Default::default())
}
Err(e) => Err(format!("permission denied reading {}: {}", path, e)),
}
} }
fn fetch_geoip_db(config: &Config, db_name: &str) -> Result<ProcessedDb, String> { fn fetch_geoip_db(config: &Config, db_name: &str) -> Result<ProcessedDb, String> {
@ -51,19 +81,15 @@ fn fetch_geoip_db(config: &Config, db_name: &str) -> Result<ProcessedDb, String>
unpack_path.push(&config.db.path); unpack_path.push(&config.db.path);
unpack_path.push(format!("{}.mmdb", db_name)); unpack_path.push(format!("{}.mmdb", db_name));
info!("unpack path = {:?}", unpack_path);
let url = format!("https://download.maxmind.com/app/geoip_download?edition_id={}&license_key={}&suffix=tar.gz", db_name, config.db.maxmind_key); let url = format!("https://download.maxmind.com/app/geoip_download?edition_id={}&license_key={}&suffix=tar.gz", db_name, config.db.maxmind_key);
info!("fetching db from = {}", url); info!("path = {:?} fetching db from = {}", unpack_path, url);
let response = ureq::get(&url).call(); let response = ureq::get(&url).call();
let db = match response { match response {
Ok(v) if v.status() != 200 => { Ok(v) if v.status() != 200 => {
warn!("response from maxmind is not 200 = {}", v.status()); warn!("response from maxmind is not 200 = {}", v.status());
maxmind::MaxmindDB::from_file(&unpack_path.to_string_lossy())?
} }
Ok(resp) => { Ok(resp) => {
let reader = resp.into_reader(); let reader = resp.into_reader();
@ -74,40 +100,53 @@ fn fetch_geoip_db(config: &Config, db_name: &str) -> Result<ProcessedDb, String>
.entries() .entries()
.map_err(|e| format!("error in listing files in the archive: {}", e))?; .map_err(|e| format!("error in listing files in the archive: {}", e))?;
let mut db_entry = entries let db_entry = entries
.into_iter() .into_iter()
.filter_map(|e| e.ok()) .filter_map(|e| e.ok())
.filter_map(|entry| { .filter_map(|entry| {
let path = match entry.path() { let Ok(path) = entry.path() else {
Ok(v) => v, return None;
Err(_) => return None,
}; };
if path.extension().is_none_or(|x| x != "mmdb") { if path.extension().is_none_or(|x| x != "mmdb") {
return None; return None;
} }
Some(entry) Some(entry)
}) })
.next() .next();
.unwrap();
let Some(mut db_entry) = db_entry else {
return Err("error in finding mmdb file in the tarball".to_string());
};
db_entry.unpack(&unpack_path).map_err(|e| e.to_string())?; db_entry.unpack(&unpack_path).map_err(|e| e.to_string())?;
maxmind::MaxmindDB::from_file(&unpack_path.to_string_lossy())?
} }
Err(e) => { Err(e) => {
warn!("error in fetching db from maxmind: {}", e); warn!("error in fetching db from maxmind: {}", e);
maxmind::MaxmindDB::from_file(&unpack_path.to_string_lossy())?
} }
}; };
let db = maxmind::MaxmindDB::from_file(&unpack_path.to_string_lossy())?;
info!("downloaded {}", db_name); info!("downloaded {}", db_name);
match db_name { match db_name {
COUNTRY_DB => Ok(db.consume_country_database(&config.source_countries)), COUNTRY_DB => Ok(db.consume(|data| -> bool {
ASN_DB => Ok(db.consume_asn_database(&config.source_asn)), let Some(Data::Map(country)) = data.get("country".as_bytes()) else {
return false;
};
let Some(iso_code) = country.get("iso_code".as_bytes()) else {
return false;
};
config.source_countries.contains(&iso_code.to_string())
})),
ASN_DB => Ok(db.consume(|data| -> bool {
let Some(Data::U32(asn)) = data.get("autonomous_system_number".as_bytes()) else {
return false;
};
config.source_asn.contains(asn)
})),
_ => Err("unknown db".to_string()), _ => Err("unknown db".to_string()),
} }

View File

@ -51,7 +51,7 @@ impl Display for Data<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match self { match self {
Data::String(s) => f.write_str(&String::from_utf8_lossy(s)), Data::String(s) => f.write_str(&String::from_utf8_lossy(s)),
Data::Double(s) => f.write_str(&s.to_string()), Data::Double(s) => write!(f, "{s}"),
Data::Bytes(s) => f.write_fmt(format_args!("{:?}", s)), Data::Bytes(s) => f.write_fmt(format_args!("{:?}", s)),
Data::U16(s) => f.write_str(&s.to_string()), Data::U16(s) => f.write_str(&s.to_string()),
Data::U32(s) => f.write_str(&s.to_string()), Data::U32(s) => f.write_str(&s.to_string()),
@ -168,6 +168,7 @@ impl MaxmindDB {
} }
} }
#[allow(unused)]
pub fn lookup(&self, addr: IpAddr) -> Option<Data> { pub fn lookup(&self, addr: IpAddr) -> Option<Data> {
let node_size = self.metadata.record_size as usize * 2 / 8; let node_size = self.metadata.record_size as usize * 2 / 8;
let mut node = 0; let mut node = 0;
@ -197,7 +198,7 @@ impl MaxmindDB {
} }
} }
pub fn consume_country_database(mut self, countries: &FxHashSet<String>) -> ProcessedDb { pub fn consume(mut self, should_block: impl Fn(FxHashMap<&[u8], Data>) -> bool) -> ProcessedDb {
let mut stack = vec![]; let mut stack = vec![];
let node_size = self.metadata.record_size as usize * 2 / 8; let node_size = self.metadata.record_size as usize * 2 / 8;
stack.push((0, 0)); stack.push((0, 0));
@ -230,71 +231,8 @@ impl MaxmindDB {
let Data::Map(data) = data else { let Data::Map(data) = data else {
unreachable!() unreachable!()
}; };
let Some(Data::Map(country)) = data.get("country".as_bytes()) else {
continue;
};
let Some(iso_code) = country.get("iso_code".as_bytes()) else {
unreachable!()
};
if countries.contains(&iso_code.to_string()) { if should_block(data) {
// Mark this node as non existent
Self::write_over_node_bytes(
&mut self.data
[node as usize * node_size..(node as usize * node_size) + node_size],
0,
self.metadata.record_size,
BLOCK_MARKER,
);
}
}
// Trim database to only contain the binary tree
ProcessedDb {
node_count: self.metadata.node_count,
record_size: self.metadata.record_size,
db: self.data[..self.metadata.data_section_start].to_vec(),
}
}
pub fn consume_asn_database(mut self, asns: &FxHashSet<u32>) -> ProcessedDb {
let mut stack = vec![];
let node_size = self.metadata.record_size as usize * 2 / 8;
stack.push((0, 0));
while let Some((node, position)) = stack.pop() {
let n =
&mut self.data[node as usize * node_size..(node as usize * node_size) + node_size];
let node_1 = Self::node_from_bytes(n, 0, self.metadata.record_size);
let node_2 = Self::node_from_bytes(n, 1, self.metadata.record_size);
if position < 128 && node_1 < self.metadata.node_count {
stack.push((node_1, position + 1));
}
if position < 128 && node_2 < self.metadata.node_count {
stack.push((node_2, position + 1));
}
let data_section_offset = if node_1 != BLOCK_MARKER && node_1 > self.metadata.node_count
{
node_1 - self.metadata.node_count
} else if node_2 != BLOCK_MARKER && node_2 > self.metadata.node_count {
node_2 - self.metadata.node_count
} else {
continue;
};
let (data, _) = self
.read_data(self.metadata.data_section_start + data_section_offset as usize - 16);
let Data::Map(data) = data else {
unreachable!()
};
let Some(Data::U32(asn)) = data.get("autonomous_system_number".as_bytes()) else {
continue;
};
if asns.contains(asn) {
// Mark this node as non existent // Mark this node as non existent
Self::write_over_node_bytes( Self::write_over_node_bytes(
&mut self.data &mut self.data