diff --git a/src/header.rs b/src/header.rs index 4a3faf7..48912d0 100644 --- a/src/header.rs +++ b/src/header.rs @@ -1,4 +1,4 @@ -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Header { pub ident: u16, pub query: bool, diff --git a/src/main.rs b/src/main.rs index a808662..0da4264 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,7 @@ use std::net::{SocketAddr, UdpSocket}; fn main() { let args: Vec = std::env::args().collect(); - let resolver: SocketAddr = args[2].parse().unwrap(); + let resolver: SocketAddr = args[2].parse().expect("resolver address not provided"); println!("resolver = {:?}", resolver); @@ -19,55 +19,75 @@ fn main() { let udp_socket = UdpSocket::bind("127.0.0.1:2053").expect("Failed to bind to address"); let mut buf = [0; 512]; - 'outer: loop { + loop { match udp_socket.recv_from(&mut buf) { Ok((size, source)) => { println!("Received {} bytes from {}", size, source); let received_data = &buf[0..size]; + let mut recv_packet = Packet::parse(received_data).unwrap(); - println!("{:0x?}", received_data); + for question in recv_packet.questions.iter() { + let mut packet = recv_packet.clone(); + packet.header.qd_count = 1; + packet.questions = vec![question.clone()]; - upstream_socket - .send_to(received_data, resolver) - .expect("error in sending data to upstream"); + let mut data = vec![]; + packet.write_to(&mut data); - let mut lbuf = [0; 512]; + upstream_socket + .send_to(&data, resolver) + .expect("error in sending data to upstream"); + } + recv_packet.header.query = true; + recv_packet.header.authoritative = false; + recv_packet.header.truncated = false; + recv_packet.header.recursion_avail = false; + recv_packet.header.reserved = 0; + recv_packet.header.rcode = if recv_packet.header.opcode == 0 { 0 } else { 4 }; + recv_packet.header.an_count = recv_packet.header.qd_count; + recv_packet.header.authority_records = 0; + recv_packet.header.additional_records = 0; - match upstream_socket.recv_from(&mut lbuf) { - Ok((size, upstream)) => { - println!( - "Received {}bytes from {} on upstream socket", - size, upstream - ); + let mut responses = vec![[0; 512]; recv_packet.header.qd_count as usize]; + let mut upstream_packets = vec![None; recv_packet.header.qd_count as usize]; - let mut recv_packet = Packet::parse(received_data).unwrap(); - let upstream_packet = Packet::parse(&lbuf).unwrap(); + for (lbuf, packet) in responses.iter_mut().zip(upstream_packets.iter_mut()) { + match upstream_socket.recv_from(lbuf) { + Ok((size, upstream)) => { + println!( + "Received {}bytes from {} on upstream socket", + size, upstream + ); - recv_packet.header.query = true; - recv_packet.header.authoritative = false; - recv_packet.header.truncated = false; - recv_packet.header.recursion_avail = false; - recv_packet.header.reserved = 0; - recv_packet.header.rcode = - if recv_packet.header.opcode == 0 { 0 } else { 4 }; - recv_packet.header.an_count = upstream_packet.header.an_count; - recv_packet.header.authority_records = 0; - recv_packet.header.additional_records = 0; - - recv_packet.answers = upstream_packet.answers; - - let mut response = vec![]; - recv_packet.write_to(&mut response); - - udp_socket - .send_to(&response, source) - .expect("Failed to send response"); - } - Err(e) => { - eprintln!("error in receving data from upstream: {}", e); - continue 'outer; + *packet = Packet::parse(lbuf).ok(); + } + Err(e) => { + eprintln!("error in receiving data from upstream: {}", e); + } } } + + recv_packet.answers = upstream_packets + .into_iter() + .filter(|x| x.is_some()) + .flat_map(|packet| { + if let Some(packet) = packet { + if packet.answers.is_empty() { + return None; + } + Some(packet.answers[0].clone()) + } else { + None + } + }) + .collect(); + + let mut response = vec![]; + recv_packet.write_to(&mut response); + + udp_socket + .send_to(&response, source) + .expect("Failed to send response"); } Err(e) => { eprintln!("Error receiving data: {}", e); @@ -77,7 +97,7 @@ fn main() { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct Packet<'a> { header: Header, questions: Vec>, diff --git a/src/question.rs b/src/question.rs index f30f039..d6c9097 100644 --- a/src/question.rs +++ b/src/question.rs @@ -1,6 +1,6 @@ use crate::qname::Qname; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Question<'a> { pub name: Qname<'a>, pub q_type: u16, diff --git a/src/rrecord.rs b/src/rrecord.rs index 4ec82da..ca7d4a2 100644 --- a/src/rrecord.rs +++ b/src/rrecord.rs @@ -1,6 +1,6 @@ use crate::qname::Qname; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct RRecord<'a> { pub name: Qname<'a>, pub r_type: u16, @@ -60,17 +60,15 @@ impl<'a> RRecord<'a> { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum RData { A([u8; 4]), - Aaaa([u8; 16]), } impl RData { pub fn write_to(self, buf: &mut Vec) { match self { RData::A(addr) => buf.extend(addr), - RData::Aaaa(addr) => buf.extend(addr), } } }