diff --git a/src/main.rs b/src/main.rs index ff5ab21..eeea4e4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,49 +6,66 @@ use question::*; mod rrecord; use rrecord::*; -use std::net::UdpSocket; +use std::net::{SocketAddr, UdpSocket}; fn main() { + let args: Vec = std::env::args().collect(); + let resolver: SocketAddr = args[2].parse().unwrap(); + + println!("resolver = {:?}", resolver); + + let upstream_socket = + UdpSocket::bind("0.0.0.0:31000").expect("Failed to bind to upstream address"); let udp_socket = UdpSocket::bind("127.0.0.1:2053").expect("Failed to bind to address"); let mut buf = [0; 512]; - loop { + 'outer: loop { match udp_socket.recv_from(&mut buf) { Ok((size, source)) => { - let received_data = &buf[0..size]; println!("Received {} bytes from {}", size, source); - let mut response = vec![]; + let received_data = &buf[0..size]; - let mut recv_packet = Packet::parse(received_data).unwrap(); + upstream_socket + .send_to(received_data, resolver) + .expect("error in sending data to upstream"); - println!("{:?}", recv_packet); + let mut lbuf = [0; 512]; - recv_packet.header.query = true; - recv_packet.header.authoritative = false; - recv_packet.header.truncated = false; - recv_packet.header.recursion_avail = false; - recv_packet.header.reserved = 0; - recv_packet.header.rcode = if recv_packet.header.opcode == 0 { 0 } else { 4 }; - recv_packet.header.an_count = recv_packet.header.qd_count; - recv_packet.header.authority_records = 0; - recv_packet.header.additional_records = 0; + match upstream_socket.recv_from(&mut lbuf) { + Ok((size, upstream)) => { + println!( + "Received {}bytes from {} on upstream socket", + size, upstream + ); - for question in recv_packet.questions.iter() { - recv_packet.answers.push(RRecord { - name: question.name.clone(), - r_type: 1, - class: 1, - ttl: 1337, - rdlength: 4, - rdata: RData::A([0x8, 0x8, 0x8, 0x8]), - }) + let mut recv_packet = Packet::parse(received_data).unwrap(); + let upstream_packet = Packet::parse(&lbuf).unwrap(); + + recv_packet.header.query = true; + recv_packet.header.authoritative = false; + recv_packet.header.truncated = false; + recv_packet.header.recursion_avail = false; + recv_packet.header.reserved = 0; + recv_packet.header.rcode = + if recv_packet.header.opcode == 0 { 0 } else { 4 }; + recv_packet.header.an_count = recv_packet.header.qd_count; + recv_packet.header.authority_records = 0; + recv_packet.header.additional_records = 0; + + recv_packet.answers = upstream_packet.answers; + + let mut response = vec![]; + recv_packet.write_to(&mut response); + + udp_socket + .send_to(&response, source) + .expect("Failed to send response"); + } + Err(e) => { + eprintln!("error in receving data from upstream: {}", e); + continue 'outer; + } } - - recv_packet.write_to(&mut response); - - udp_socket - .send_to(&response, source) - .expect("Failed to send response"); } Err(e) => { eprintln!("Error receiving data: {}", e); @@ -75,9 +92,6 @@ impl<'a> Packet<'a> { for _ in 0..header.qd_count { let rec = Question::parse(data, original)?; - - println!("label = {}", rec.name); - data = &data[rec.length()..]; questions.push(rec); } @@ -126,10 +140,16 @@ mod test { 0x7, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67, 0xc0, 0xc, 0x00, 0x01, 0x00, 0x01, ]; - let packet = Packet::parse(&data); + let packet = match Packet::parse(&data) { + Ok(v) => v, + Err(e) => { + panic!("failed to parse packet: {}", e); + } + }; - assert!(packet.is_ok()); + let mut buf = Vec::with_capacity(data.len()); + packet.write_to(&mut buf); - println!("{:?}", packet); + assert_eq!(buf, data); } } diff --git a/src/qname.rs b/src/qname.rs index ab7d2fe..183856f 100644 --- a/src/qname.rs +++ b/src/qname.rs @@ -54,31 +54,6 @@ impl<'a> Qname<'a> { pub fn length(&self) -> usize { self.length } - - pub fn labels(&'a self) -> Vec<&'a [u8]> { - let mut lookup: &[u8] = &self.name; - - let mut out = vec![]; - - let mut i = 0; - while i < lookup.len() - 1 && lookup[i + 1] != 0 { - let is_pointer = (lookup[i] & 0xc0) == 0xc0; - - if is_pointer { - let offset: u16 = (((lookup[i] & !0xc0) as u16) << 8) | lookup[i + 1] as u16; - - i = offset as usize; - lookup = &self.original; - } else { - let length = lookup[i] as usize; - - out.push(&lookup[i + 1..i + length + 1]); - i += length + 1; - } - } - - out - } } impl<'a> Display for Qname<'a> {