/*-
* cdns-rs - a simple sync/async DNS query library
* Copyright (C) 2020  Aleksandr Morozov, RELKOM s.r.o
* Copyright (C) 2021-2022  Aleksandr Morozov
* 
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3 of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
* Lesser General Public License for more details.
* 
* You should have received a copy of the GNU Lesser General Public License
* along with this program; if not, write to the Free Software Foundation,
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
*/


use std::collections::{LinkedList, HashSet};
use std::convert::{TryFrom};
use std::net::{IpAddr};
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;

use tokio::task::JoinHandle;

use async_recursion::async_recursion;

use crate::cfg_resolv_parser::{ResolveConfigFamily, ResolveConfEntry};
use crate::error::*;
use crate::query::{QDnsQuery, QDnsQueriesRes, QDnsQueryRec};
use crate::sync::QuerySetup;
use crate::{write_error, internal_error, internal_error_map};
use crate::query_private::QDnsReq;

use super::common::*;
use super::caches::CACHE;
use super::network::{new_udp, new_tcp, NetworkTapType};
use super::query_async_taps::{AsyncTaps, Tap};
use super::{ResolveConfig};


/// Outcome of a single request/response exchange with a nameserver.
enum SpawnFutereRes
{
    /// The answer passed verification and was converted into a [QDnsQuery].
    Ok(QDnsQuery),
    /// Verification reported a truncated message. Carries the original request
    /// header so that the request can be sent once more.
    Truncated(DnsRequestHeader),
}


/// The main query instance. It holds the resolver configuration, the queued
/// requests and the per-query option overrides; the lookup itself is started
/// with [QDns::query].
pub struct QDns<'req>
{
    /// Shared resolver configuration (parsed /etc/resolv.conf or a custom config).
    /// The lookup order, option flags, address family, timeout and the list of
    /// nameservers are read from it.
    resolvers: Arc<ResolveConfig>,
    /// A pre-ordered list of the requests, if more than one. Requests which were
    /// answered from the hosts list are removed from it by the file lookup.
    ordered_req_list: Vec<QDnsReq<'req>>,
    /// Override options (i.e timeout, time measurement, ignoring the hosts file)
    opts: QuerySetup,
}

impl<'req> QDns<'req>
{
    /// Creates an instance without any request queued.
    /// 
    /// Requests are added afterwards with [QDns::add_request]. Several requests may
    /// be combined in one instance; whether that is a good idea depends on the case,
    /// because by default the requests are performed in parallel.
    /// 
    /// # Arguments
    /// 
    /// * `resolvers` - an [Arc] [ResolveConfig] which contains configuration i.e nameservers
    /// 
    /// * `planned_reqs_len` - how many requests are planned; used only to preallocate
    ///     the internal list
    /// 
    /// * `opts` - [QuerySetup] additional options or overrides. Use default() for default
    ///     values.
    /// 
    /// # Returns
    /// 
    /// A new instance with an empty request list.
    pub 
    fn make_empty(resolvers: Arc<ResolveConfig>, planned_reqs_len: usize, opts: QuerySetup) -> QDns<'req>
    {
        let ordered_req_list = Vec::with_capacity(planned_reqs_len);

        Self { resolvers, ordered_req_list, opts }
    }

    /// Queues one more request on an instance created with [QDns::make_empty].
    /// 
    /// # Arguments
    /// 
    /// * `qtype` - a [QType] type of the request
    /// 
    /// * `req_name` - anything convertible [Into] a [QDnsName], which is the target
    ///     i.e 'localhost' or 'domain.tld'
    /// 
    /// # Returns
    /// 
    /// Nothing. The request is appended to the end of the internal list.
    pub 
    fn add_request<R>(&mut self, qtype: QType, req_name: R)
    where R: Into<QDnsName<'req>>
    {
        let name: QDnsName<'req> = req_name.into();

        self.ordered_req_list.push(QDnsReq::new(name, qtype));
    }

    /// This is helper which makes for you an A, AAAA query. The order of A and AAAA and
    /// which from both are allowed is defined in the [ResolveConfig].
    /// 
    /// Use this function directly. Do not use [QDns::make_empty]
    /// 
    /// # Arguments
    /// 
    /// * `resolvers` - an [Arc] [ResolveConfig] which contains configuration i.e nameservers
    /// 
    /// * `req_name` - a [Into] [QDnsName] which is target i.e 'localhost' or 'domain.tld'
    /// 
    /// * `opts` - [QuerySetup] additional options or overrides. Use default() for default
    ///     values.
    /// 
    /// # Returns
    /// 
    /// * [CDnsResult] - Ok with Self as inner type
    /// 
    /// * [CDnsResult] - Err with error description
    pub 
    fn make_a_aaaa_request<R>(resolvers: Arc<ResolveConfig>, req_name: R, opts: QuerySetup) -> QDns<'req>
    where R: Into<QDnsName<'req>>
    {
        // store the A and AAAA depending on order
        let reqs: Vec<QDnsReq<'req>> = 
            match resolvers.family
            {
                ResolveConfigFamily::INET4_INET6 => 
                {
                    let req_n: QDnsName = req_name.into();

                    vec![
                        QDnsReq::<'req>::new(req_n.clone(), QType::A),
                        QDnsReq::<'req>::new(req_n, QType::AAAA),
                    ]
                },
                ResolveConfigFamily::INET6_INET4 => 
                {
                    let req_n: QDnsName = req_name.into();

                    vec![
                        QDnsReq::<'req>::new(req_n.clone(), QType::AAAA),
                        QDnsReq::<'req>::new(req_n, QType::A),
                    ]
                },
                ResolveConfigFamily::INET6 => 
                {
                    vec![
                        QDnsReq::<'req>::new(req_name.into(), QType::AAAA),
                    ]
                },
                ResolveConfigFamily::INET4 => 
                {
                    vec![
                        QDnsReq::<'req>::new(req_name.into(), QType::A),
                    ]
                }
                _ =>
                {
                    // set default
                    let req_n: QDnsName<'req> = req_name.into();

                    vec![
                        QDnsReq::<'req>::new(req_n.clone(), QType::A),
                        QDnsReq::<'req>::new(req_n, QType::AAAA),
                    ]
                }
            };

        
        
        return
            Self
            {
                resolvers: resolvers,
                ordered_req_list: reqs,
                opts: opts,
            };
    }

    /// Runs the queued query/ies and returns whatever could be collected.
    /// 
    /// The order of the two lookup steps (hosts list and nameservers) depends on
    /// the `lookup` setting of the [ResolveConfig]. An error from a step is only
    /// logged; it does not abort the other step.
    /// 
    /// # Returns
    /// 
    /// * [QDnsQueriesRes] extended with the results of every step which succeeded.
    ///     It starts as [QDnsQueriesRes::DnsNotAvailable].
    pub async 
    fn query(mut self) -> QDnsQueriesRes
    {
        // a timestamp is taken only when the time measurement was requested
        let started: Option<Instant> = 
            if self.opts.measure_time
            {
                Some(Instant::now())
            }
            else
            {
                None
            };

        let mut results: QDnsQueriesRes = QDnsQueriesRes::DnsNotAvailable;

        if self.resolvers.lookup.is_file_first()
        {
            // hosts list first ...
            match self.lookup_file(started).await
            {
                Err(e) =>
                    write_error!("{}", e),
                Ok(found) =>
                    results.extend(found),
            }

            // ... then the nameservers for everything which is still queued
            let ask_nameservers = 
                !self.ordered_req_list.is_empty() && self.resolvers.lookup.is_bind();

            if ask_nameservers
            {
                match self.process_request(started).await
                {
                    Err(e) =>
                        write_error!("{}", e),
                    Ok(found) => 
                        results.extend(found),
                }
            }

            return results;
        }

        // nameservers first ...
        match self.process_request(started).await
        {
            Err(e) =>
                write_error!("{}", e),
            Ok(found) => 
                results.extend(found),
        }

        // ... then the hosts list. The nameserver step does not remove anything
        // from the queue, so only the lookup setting decides here in practice.
        let ask_hosts = 
            !self.ordered_req_list.is_empty() && self.resolvers.lookup.is_file();

        if ask_hosts
        {
            match self.lookup_file(started).await
            {
                Err(e) =>
                    write_error!("{}", e),
                Ok(found) => 
                    results.extend(found),
            }
        }

        results
    }

    /// Returns the timeout in effect: the override from the options if it is set,
    /// otherwise the value from the resolver configuration. Both are taken as seconds.
    fn get_timeout(&self) -> Duration
    {
        let secs: u64 = 
            match self.opts.timeout
            {
                Some(overridden) => overridden as u64,
                None => self.resolvers.timeout as u64,
            };

        Duration::from_secs(secs)
    }

    /// Tries to answer the queued requests from the cached hosts list (/etc/hosts).
    /// 
    /// Only A, AAAA and PTR requests are looked up. A request which was answered is
    /// removed from the queue, everything else stays queued. If the hosts lookup is
    /// disabled by the `ign_hosts` option, the queue is left untouched and
    /// [QDnsQueriesRes::DnsNotAvailable] is returned.
    async 
    fn lookup_file(&mut self, now: Option<Instant>) -> CDnsResult<QDnsQueriesRes>
    {
        // check if the hosts lookup is overridden
        if self.opts.ign_hosts
        {
            return Ok(QDnsQueriesRes::DnsNotAvailable);
        }

        let hosts = CACHE.clone_host_list().await?;

        let mut answered: LinkedList<QDnsQuery> = LinkedList::new();

        // the closure yields `true` to keep a request queued, `false` to drop it
        self.ordered_req_list.retain(|req| 
            {
                match *req.get_type()
                {
                    QType::A | QType::AAAA => 
                    {
                        let fqdn = String::from(req.get_req_name());

                        let found = 
                            match hosts.search_by_fqdn(req.get_type(), fqdn.as_str())
                            {
                                Some(found) => found,
                                None => return true,
                            };

                        // create storage for response
                        match DnsResponsePayload::new_local(*req.get_type(), found)
                        {
                            Ok(payload) =>
                            {
                                answered.push_back(
                                    QDnsQuery::from_local(payload, now.as_ref())
                                );

                                false
                            },
                            Err(e) =>
                            {
                                write_error!("{}", e);

                                true
                            }
                        }
                    },
                    QType::PTR => 
                    {
                        let ip: IpAddr = 
                            match IpAddr::try_from(req.get_req_name())
                            {
                                Ok(parsed) => parsed,
                                Err(e) =>
                                {
                                    // not an address, leave it for the other lookup
                                    write_error!("{}", e);

                                    return true;
                                }
                            };

                        let found = 
                            match hosts.search_by_ip(&ip)
                            {
                                Some(found) => found,
                                None => return true,
                            };

                        // create storage for response
                        match DnsResponsePayload::new_local(QType::PTR, found)
                        {
                            Ok(payload) =>
                            {
                                answered.push_back(
                                    QDnsQuery::from_local(payload, now.as_ref())
                                );

                                false
                            },
                            Err(e) =>
                            {
                                write_error!("{}", e);

                                true
                            }
                        }
                    },
                    // other record types are not searched here, not an error
                    _ => true,
                }
            }
        );

        Ok( QDnsQueriesRes::from(answered) )
    }

    /// Creates one socket (tap) towards `resolver` on port 53.
    /// 
    /// A TCP tap is created if either the config forces TCP or `force_tcp` is set,
    /// otherwise a UDP tap.
    fn create_socket(
        &self,
        force_tcp: bool, 
        resolver: &ResolveConfEntry, 
        timeout: Duration
    ) -> CDnsResult<Box<NetworkTapType>>
    {
        // `force_tcp` is set on an attempt to switch from UDP to TCP
        let use_tcp = self.resolvers.option_flags.is_force_tcp() || force_tcp;

        if use_tcp
        {
            new_tcp(resolver.get_resolver_ip(), 53, resolver.get_adapter_ip(), timeout)
        }
        else
        {
            new_udp(resolver.get_resolver_ip(), 53, resolver.get_adapter_ip(), timeout)
        }
    }

    /// Builds the taps (a socket plus the requests assigned to it) for one nameserver.
    /// 
    /// A new socket is created for the first request and, if the `reopen_socket`
    /// option is set, for every following request as well. Otherwise the following
    /// requests are attached to the most recently created tap.
    /// 
    /// Every request is given an ID which is unique within the batch, no matter
    /// whether it gets its own socket or shares one.
    /// 
    /// # Arguments
    /// 
    /// * `resolver` - the nameserver the sockets are created for
    /// 
    /// * `requery_list` - if set, these request headers are sent (with regenerated
    ///     IDs) instead of the queued requests
    /// 
    /// * `timeout` - passed on to the socket constructor
    /// 
    /// * `force_tcp` - request TCP sockets even if the config does not force TCP
    /// 
    /// # Returns
    /// 
    /// * [CDnsResult] - Ok with the taps
    /// 
    /// * [CDnsResult] - Err if a socket or a request header could not be created
    fn create_sockets(
        &self,
        resolver: &ResolveConfEntry, 
        requery_list: Option<LinkedList<DnsRequestHeader>>,
        timeout: Duration,
        force_tcp: bool,
    ) -> CDnsResult<AsyncTaps>
    {
        let mut taps: AsyncTaps = AsyncTaps::new_with_capacity(self.ordered_req_list.len());

        let force_tcp: bool = 
            force_tcp || self.resolvers.option_flags.is_force_tcp();

        let reopen_socket: bool = self.resolvers.option_flags.is_reopen_socket();

        if let Some(requery) = requery_list
        {
            let mut ids: HashSet<u16> = HashSet::with_capacity(requery.len());
            
            for mut req in requery
            {
                // a request which is sent again always gets a fresh ID; repeat
                // until the ID is unique within this batch
                loop
                {
                    req.regenerate_id();

                    if ids.insert(req.get_id())
                    {
                        break;
                    }
                }

                if reopen_socket || taps.len() == 0
                { 
                    // create socket and store
                    let tap = 
                        self.create_socket(force_tcp, resolver, timeout)?;

                    taps.push(Tap::new(tap, req));
                }
                else
                {
                    taps.push_to_last(req);
                }
            }
        }
        else
        {
            let mut ids: HashSet<u16> = HashSet::with_capacity(self.ordered_req_list.len());

            // form the list of requests
            for req in self.ordered_req_list.iter()
            {
                let mut drh_req = DnsRequestHeader::from_qdns_req(req, self.resolvers.as_ref())?;

                // keep the generated ID unless it collides with an earlier request
                while !ids.insert(drh_req.get_id())
                {
                    drh_req.regenerate_id();
                }

                if reopen_socket || taps.len() == 0
                { 
                    // create socket and store
                    let tap = 
                        self.create_socket(force_tcp, resolver, timeout)?;

                    taps.push(Tap::new(tap, drh_req));
                }
                else
                {
                    taps.push_to_last(drh_req);
                }
            }
        }

        Ok(taps)
    }

    /// Returns the status of the first response which is not ok, or
    /// [QDnsQueryRec::Ok] if every response is ok (or the list is empty).
    fn get_result(responses: &LinkedList<QDnsQuery>) -> QDnsQueryRec
    {
        responses
            .iter()
            .find(|r| !r.is_ok())
            .map_or(QDnsQueryRec::Ok, |failed| failed.status)
    }

    /// Returns the `aa` flag of the first record in `responses`, or `false` if
    /// the list is empty.
    fn get_authorative(responses: &LinkedList<QDnsQuery>) -> bool
    {
        match responses.front()
        {
            Some(first) => first.aa,
            None => false,
        }
    }

    /// Queries the nameservers one after another.
    /// 
    /// The answers of every nameserver which was asked are collected. After each
    /// nameserver the overall status and the `aa` flag of its answers decide (via
    /// `try_next_nameserver`) whether the next nameserver is asked too.
    async 
    fn process_request(&mut self, now: Option<Instant>) -> CDnsResult<QDnsQueriesRes>
    {
        let mut collected: LinkedList<QDnsQuery> = LinkedList::new();

        for resolver in self.resolvers.get_resolvers_iter()?
        {
            let answers = 
                self.processing(now, resolver, None, false).await?;
            
            // both are read before `answers` is moved into the collection
            let status = Self::get_result(&answers);
            let authoritative = Self::get_authorative(&answers);

            collected.extend(answers);

            if !status.try_next_nameserver(authoritative)
            {
                break;
            }
        }

        Ok(QDnsQueriesRes::from(collected))
    }

    /// Sends the request/s to one nameserver and collects the answers.
    /// 
    /// If `requery` is set, those request headers are sent instead of the queued
    /// requests. With the `no_parallel` option the requests are performed one by one
    /// in place; otherwise one task per tap is spawned and awaited afterwards.
    /// 
    /// Requests which were answered truncated are collected and sent once more by a
    /// recursive call with `force_tcp` set. A truncated answer while TCP is already
    /// in use is returned as an error.
    #[async_recursion]
    async 
    fn processing(
        &self, 
        now: Option<Instant>, 
        resolver: &ResolveConfEntry,
        requery: Option<LinkedList<DnsRequestHeader>>,
        force_tcp: bool
    ) -> CDnsResult<LinkedList<QDnsQuery>>
    {
        let mut answers: LinkedList<QDnsQuery> = LinkedList::new();

        // requests which have to be repeated because the answer was truncated
        let mut retry_list: LinkedList<DnsRequestHeader> = LinkedList::new();

        // true if the requests of this call already travel over TCP
        let tcp_in_use: bool = 
            force_tcp || self.resolvers.option_flags.is_force_tcp();

        // the requests bound to their taps
        let taps: AsyncTaps = 
            self.create_sockets(resolver, requery, self.get_timeout(), force_tcp)?;

        // handles of the spawned tasks, filled in parallel mode only
        let mut jobs: Vec<JoinHandle<CDnsResult<Vec<SpawnFutereRes>>>> = Vec::with_capacity(taps.len());

        for tap in taps.into_iter()
        {
            let (mut socktap, req_list) = tap.into_inner();

            if self.resolvers.option_flags.is_no_parallel()
            {
                // --- no parallel ----
                socktap.connect().await?;

                for req in req_list
                {
                    let outcome = 
                        Self::spawn_future_single(socktap.as_mut(), req, now).await?;

                    match outcome
                    {
                        SpawnFutereRes::Ok(dq) =>
                        {
                            answers.push_back(dq);
                        },
                        SpawnFutereRes::Truncated(drh) =>
                        {
                            if tcp_in_use
                            {
                                // throw error
                                internal_error!(CDnsErrorType::MessageTruncated, "Message is truncated even using TCP. Give up.");
                            }

                            retry_list.push_back(drh);
                        },
                    }
                }

                continue;
            }

            // --- parallel ----
            let job: JoinHandle<CDnsResult<Vec<SpawnFutereRes>>> = 
                tokio::spawn( 
                    async move 
                    { 
                        Self::spawn_future_multi(socktap, req_list, now).await
                    } 
                );

            jobs.push(job);
        }

        // nothing to do here in the no parallel mode, the list is empty then
        for job in jobs
        {
            // the first `?` handles a failed task, the second the result of the task
            let batch = 
                job.await
                    .map_err(|e| 
                        internal_error_map!(CDnsErrorType::InternalError, "{}", e)
                    )??;

            for outcome in batch
            {
                match outcome
                {
                    SpawnFutereRes::Ok(dq) =>
                    {
                        answers.push_back(dq);
                    },
                    SpawnFutereRes::Truncated(drh) =>
                    {
                        if tcp_in_use
                        {
                            // throw error
                            internal_error!(CDnsErrorType::MessageTruncated, "Message is truncated even using TCP. Give up.");
                        }

                        retry_list.push_back(drh);
                    },
                }
            }
        }

        if !retry_list.is_empty()
        {
            // before leaving, what was truncated has to be resolved
            if tcp_in_use
            {
                // throw error
                internal_error!(CDnsErrorType::MessageTruncated, "Message is truncated even using TCP. Give up.");
            }

            let retried = 
                self.processing(now, resolver, Some(retry_list), true).await?;

            answers.extend(retried);
        }
        
        Ok(answers)
    }

    /// Performs a single request on `sock`, which must be connected already.
    /// 
    /// The request is serialized and sent, the reply is read into a 1024 byte
    /// buffer, parsed and verified against the request.
    /// 
    /// # Returns
    /// 
    /// * [CDnsResult] - Ok with [SpawnFutereRes::Ok] if the reply was verified
    /// 
    /// * [CDnsResult] - Ok with [SpawnFutereRes::Truncated] (the request is handed
    ///     back) if the verification reported a truncated message
    /// 
    /// * [CDnsResult] - Err with any other error
    async 
    fn spawn_future_single(sock: &mut NetworkTapType, req: DnsRequestHeader, now: Option<Instant>) -> CDnsResult<SpawnFutereRes>
    {
        // construct the packet and send it over the wire
        let packet = req.async_to_bytes().await?;

        sock.send(packet.as_slice()).await?;

        // receive the reply
        let mut reply = vec![0_u8; 1024];

        sock.recv(reply.as_mut_slice()).await?;

        // parse the reply to structure
        let answer = DnsRequestAnswer::async_try_from(reply.as_slice()).await?;

        // verify the reply against the request
        if let Err(e) = answer.verify(&req)
        {
            if e.err_code == CDnsErrorType::MessageTruncated
            {
                // the message was received truncated, hand the request back
                // so that it can be sent again via TCP if not TCP
                async_write_error!("{}", e);

                return Ok(SpawnFutereRes::Truncated(req));
            }

            return Err(e);
        }

        // verified
        let resp = QDnsQuery::from_response(sock.get_remote_addr(), answer, now.as_ref())?;

        Ok(SpawnFutereRes::Ok(resp))
    }

    /// Connects `sock` and performs every request from `req_list` on it, one after
    /// another. The first error aborts the remaining requests.
    async 
    fn spawn_future_multi(mut sock: Box<NetworkTapType>, req_list: Vec<DnsRequestHeader>, now: Option<Instant>) -> CDnsResult<Vec<SpawnFutereRes>>
    {
        let mut outcomes: Vec<SpawnFutereRes> = Vec::with_capacity(req_list.len());

        // connect socket
        sock.connect().await?;

        for req in req_list.into_iter()
        {
            let outcome = 
                Self::spawn_future_single(sock.as_mut(), req, now).await?;

            outcomes.push(outcome);
        }

        Ok(outcomes)
    }
}


/// Checks that `ip2pkt` encodes 8.8.8.8 as the expected length-prefixed label
/// sequence and prints how long the encoding took.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_ip2pkt()
{
    use tokio::time::Instant;
    use std::net::{IpAddr, Ipv4Addr};

    // four labels "8", then "in-addr", "arpa" and the terminating zero byte
    let expected = b"\x01\x38\x01\x38\x01\x38\x01\x38\x07\x69\x6e\x2d\x61\x64\x64\x72\x04\x61\x72\x70\x61\x00";

    let addr = IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8));

    let started = Instant::now();

    let encoded = ip2pkt(&addr);

    println!("Elapsed: {:.2?}", started.elapsed());

    assert_eq!(encoded.is_ok(), true, "err: {}", encoded.err().unwrap());

    let bytes = encoded.unwrap();

    assert_eq!(bytes.as_slice(), expected);
}


/// Checks `byte2hexchar` at the edges of the digit and the letter ranges.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_byte2hexchar()
{
    // digits
    assert_eq!(byte2hexchar(1), b'1');
    assert_eq!(byte2hexchar(9), b'9');

    // lower case letters
    assert_eq!(byte2hexchar(10), b'a');
    assert_eq!(byte2hexchar(15), b'f');
}


/// PTR lookup of 8.8.8.8; expects a single `dns.google` record. Uses the resolver
/// configuration from the cache, so presumably it needs working nameservers.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn reverse_lookup_test()
{
    use tokio::time::Instant;
    
    let addr: IpAddr = "8.8.8.8".parse().unwrap();
    let name = QDnsName::from(&addr);

    let resolvers = CACHE.clone_resolve_list().await.unwrap();

    let mut setup = QuerySetup::default();
    setup.set_measure_time(true);

    let started = Instant::now();

    let mut lookup = QDns::make_empty(resolvers, 1, setup);
    lookup.add_request(QType::PTR, name);

    let outcome = lookup.query().await;

    println!("Elapsed: {:.2?}", started.elapsed());

    assert!(outcome.is_results());

    println!("{}", outcome);

    if let QDnsQueriesRes::DnsOk{ res: records } = outcome
    {
        let first = &records[0];

        assert_eq!(first.status, QDnsQueryRec::Ok);

        assert_eq!(first.resp.len(), 1);
        assert_eq!(first.resp[0].rdata, DnsRdata::PTR{ fqdn: "dns.google".to_string() });
    }
    else
    {
        panic!("expected DnsResultSingle");
    }
}

/// PTR lookup of 127.0.0.1; expects a single `localhost` record which was
/// answered from /etc/hosts.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn reverse_lookup_hosts_test()
{
    use tokio::time::Instant;

    let addr: IpAddr = "127.0.0.1".parse().unwrap();
    let name = QDnsName::from(&addr);
    
    let started = Instant::now();

    let mut setup = QuerySetup::default();
    setup.set_measure_time(true);

    let resolvers = CACHE.clone_resolve_list().await.unwrap();

    let mut lookup = QDns::make_empty(resolvers, 1, setup);
    lookup.add_request(QType::PTR, name);

    let outcome = lookup.query().await;
    
    println!("Elapsed: {:.2?}", started.elapsed());

    assert!(outcome.is_results());

    println!("{}", outcome);

    if let QDnsQueriesRes::DnsOk{ res: records } = outcome
    {
        let first = &records[0];

        assert_eq!(first.server.as_str(), "/etc/hosts");
        assert_eq!(first.status, QDnsQueryRec::Ok);

        assert_eq!(first.resp.len(), 1);
        assert_eq!(first.resp[0].rdata, DnsRdata::PTR{ fqdn: "localhost".to_string() });
    }
    else
    {
        panic!("expected DnsResultSingle");
    }
}


/// Runs the A/AAAA helper for `dns.google` and prints the result together with
/// the elapsed time. Nothing is asserted about the answer.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn reverse_lookup_a()
{
    use tokio::time::Instant;

    let name = QDnsName::from("dns.google");

    let mut setup = QuerySetup::default();
    setup.set_measure_time(true);

    let resolvers = CACHE.clone_resolve_list().await.unwrap();
    
    let lookup = QDns::make_a_aaaa_request(resolvers, name, setup);
    
    // only the query itself is timed
    let started = Instant::now();
    let outcome = lookup.query().await;

    println!("Elapsed: {:.2?}", started.elapsed());

    println!("{}", outcome);
}

