/*-
* cdns-rs - a simple sync/async DNS query library
* Copyright (C) 2020  Aleksandr Morozov, RELKOM s.r.o
* Copyright (C) 2021-2022  Aleksandr Morozov
* 
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3 of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
* Lesser General Public License for more details.
* 
* You should have received a copy of the GNU Lesser General Public License
* along with this program; if not, write to the Free Software Foundation,
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
*/

//! Networking code: the UDP and TCP transports used to reach a nameserver.

use std::io::ErrorKind;
use std::net::{IpAddr, SocketAddr};
use std::time::Duration;
use core::fmt::Debug;

use async_trait::async_trait;

use tokio::io::{AsyncWriteExt, AsyncReadExt};
use tokio::sync::Mutex;
use tokio::time::timeout;
use tokio::net::{UdpSocket, TcpStream};

use socket2::{Socket, Domain, Type, Protocol, SockAddr};


use crate::{internal_error, internal_error_map};

use crate::error::*;

/// A common interface to access the realizations for both [TcpStream] and
/// [UdpSocket].
///
/// Instances returned by [new_udp] and [new_tcp] start out unconnected:
/// [SocketTap::connect] must be called before [SocketTap::send] or
/// [SocketTap::recv].
#[async_trait]
pub trait SocketTap
{
    /// Connects to the remote host returned by [SocketTap::get_remote_addr].
    ///
    /// The implementations in this module return `Ok(())` without doing
    /// anything if the instance is already connected.
    async fn connect(&mut self) -> CDnsResult<()>;

    /// Tells if current instance is [TcpStream] if true, or [UdpSocket] if false
    fn is_tcp(&self) -> bool;

    /// Tells if socket/stream is connected to remote host.
    ///
    /// (sic: the misspelled name is part of the public API.)
    async fn is_conncected(&self) -> bool;

    /// Returns the remote host Ip and port.
    fn get_remote_addr(&self) -> &SocketAddr;

    /// Sends `sndbuf` to the remote host and returns the number of bytes
    /// written. Should only be called after [SocketTap::connect].
    async fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize>;

    /// Receives data transmitted from the remote host into `rcvbuf` and
    /// returns the number of bytes read.
    ///
    /// The implementations in this module wait at most the timeout the
    /// instance was created with, and return a `RequestTimeout` error when
    /// it elapses. Should only be called after [SocketTap::connect].
    async fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize>;
}


/// Formats a type-erased tap using only the information that the
/// [SocketTap] trait exposes synchronously: the remote address and the
/// transport kind.
impl Debug for dyn SocketTap + Send + Sync
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result 
    {
        // Must NOT format `self` with `{:?}` here: that dispatches straight
        // back into this impl and recurses until the stack overflows.
        // `is_conncected` is async, so it cannot be queried from `fmt`.
        f.debug_struct("SocketTap")
            .field("remote_addr", self.get_remote_addr())
            .field("is_tcp", &self.is_tcp())
            .finish()
    }
}

/// Type-erased network tap: any [SocketTap] that is also `Send + Sync`.
/// This is the boxed payload returned by [new_udp] and [new_tcp].
pub type NetworkTapType = dyn SocketTap + Send + Sync;


/// An instance of the socket/stream together with the addresses it uses.
///
/// The socket itself is created by [SocketTap::connect]; until then `sock`
/// is `None`.
///
/// `NetworkTap<T>` is `Send`/`Sync` exactly when `T` is, because every other
/// field is a plain value type. The auto traits therefore already apply to
/// [UdpSocket] and [TcpStream], which are both `Send + Sync`. A blanket
/// `unsafe impl` would wrongly make the struct `Send`/`Sync` even for a
/// non-thread-safe `T`.
#[derive(Debug)]
pub struct NetworkTap<T>
{
    /// Channel; `None` until connected
    sock: Option<T>, 
    /// Nameserver address and port
    remote_addr: SocketAddr, 
    /// Local address to use to bind socket
    bind_addr: SocketAddr,
    /// Timeout applied to network operations (e.g. waiting for a response)
    timeout: Duration,
}

#[async_trait]
impl SocketTap for NetworkTap<UdpSocket>
{
    /// Binds a UDP socket to `bind_addr` and connects it to `remote_addr`.
    /// Returns `Ok(())` without doing anything if a socket already exists.
    async
    fn connect(&mut self) -> CDnsResult<()>
    {
        if self.sock.is_some()
        {
            // already connected, nothing to do
            return Ok(());
        }

        let socket = 
            UdpSocket::bind(self.bind_addr)
                .await
                .map_err(|e| internal_error_map!(CDnsErrorType::InternalError, "{}", e))?;

        // a connected UDP socket sends to, and receives from, `remote_addr` only
        socket.connect(&self.remote_addr)
            .await
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        self.sock = Some(socket);

        return Ok(());
    }

    fn is_tcp(&self) -> bool 
    {
        return false;
    }

    async 
    fn is_conncected(&self) -> bool 
    {
        return self.sock.is_some();
    }

    fn get_remote_addr(&self) -> &SocketAddr
    {
        return &self.remote_addr;
    }

    /// Sends one datagram to the connected nameserver.
    ///
    /// Returns an `InternalError` (instead of panicking) when called before
    /// [SocketTap::connect].
    async
    fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize> 
    {
        // copied out so the error closure does not borrow `self`
        let remote_addr = self.remote_addr;
        let sock = 
            self.sock
                .as_ref()
                .ok_or_else(|| 
                    internal_error_map!(
                        CDnsErrorType::InternalError, 
                        "socket to '{}' is not connected", 
                        remote_addr
                    )
                )?;

        return 
            sock.send(sndbuf)
                .await
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e));
    }

    /// Receives one datagram, waiting at most `self.timeout`.
    ///
    /// Transient `WouldBlock`/`Interrupted` errors are retried. A datagram
    /// from a host other than `remote_addr` is reported as a `DnsResponse`
    /// error. If the timeout elapses, a `RequestTimeout` error is returned.
    /// Returns an `InternalError` (instead of panicking) when called before
    /// [SocketTap::connect].
    async 
    fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize> 
    {
        /// Reads from `sock` until a datagram arrives or a non-transient
        /// error occurs.
        async 
        fn sub_recv(sock: &UdpSocket, remote_addr: SocketAddr, rcvbuf: &mut [u8]) -> CDnsResult<usize> 
        {
            loop
            {
                match sock.recv_from(rcvbuf).await
                {
                    Ok((rcv_len, rcv_src)) =>
                    {
                        // this should not fail because socket is "connected"
                        if rcv_src != remote_addr
                        {
                            internal_error!(
                                CDnsErrorType::DnsResponse, 
                                "received answer from unknown host: '{}' exp: '{}'", 
                                rcv_src,
                                remote_addr
                            );
                        }

                        return Ok(rcv_len);
                    },
                    Err(ref e) 
                        if matches!(e.kind(), ErrorKind::WouldBlock | ErrorKind::Interrupted) =>
                    {
                        continue;
                    },
                    Err(e) =>
                    {
                        internal_error!(CDnsErrorType::IoError, "{}", e); 
                    }
                } // match
            } // loop
        }

        // copied out so the error closure does not borrow `self`
        let remote_addr = self.remote_addr;
        let sock = 
            self.sock
                .as_ref()
                .ok_or_else(|| 
                    internal_error_map!(
                        CDnsErrorType::InternalError, 
                        "socket to '{}' is not connected", 
                        remote_addr
                    )
                )?;

        // wait for timeout
        match timeout(self.timeout, sub_recv(sock, remote_addr, rcvbuf)).await
        {
            Ok(r) => return r,
            Err(e) => internal_error!(CDnsErrorType::RequestTimeout, "{}", e)
        }
    }
}

#[async_trait]
impl SocketTap for NetworkTap<TcpStream>
{
    /// Creates a TCP socket, binds it to `bind_addr` and connects it to
    /// `remote_addr`. Returns `Ok(())` without doing anything if a stream
    /// already exists.
    ///
    /// NOTE: the connect is performed with socket2's `connect_timeout`, which
    /// waits on the calling thread for at most `self.timeout`.
    async 
    fn connect(&mut self) -> CDnsResult<()> 
    {
        if self.sock.is_some()
        {
            // already connected, nothing to do
            return Ok(());
        }

        // create socket
        let socket = 
            Socket::new(Domain::for_address(self.remote_addr), Type::STREAM, Some(Protocol::TCP))
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        // bind socket
        socket.bind(&SockAddr::from(self.bind_addr))
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        // connect, bounded by the configured timeout instead of the OS default
        socket.connect_timeout(&SockAddr::from(self.remote_addr), self.timeout)
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        // tokio's `TcpStream::from_std` leaves it to the caller to put the
        // socket into non-blocking mode; a blocking fd would stall the runtime
        // thread on read and the `recv` timeout could never fire.
        socket.set_nonblocking(true)
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        // convert to TcpStream
        let socket: TcpStream = 
            TcpStream::from_std(socket.into())
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        self.sock = Some(socket);

        return Ok(());
    }

    fn is_tcp(&self) -> bool 
    {
        return true;
    }

    async 
    fn is_conncected(&self) -> bool 
    {
        return self.sock.is_some();
    }

    fn get_remote_addr(&self) -> &SocketAddr
    {
        return &self.remote_addr;
    }

    /// Writes the whole of `sndbuf` to the stream and returns its length.
    ///
    /// A single `write` may accept only part of the buffer, which would
    /// truncate the DNS message on the wire, so `write_all` is used.
    /// Returns an `InternalError` (instead of panicking) when called before
    /// [SocketTap::connect].
    async 
    fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize>  
    {
        // copied out so the error closure does not borrow `self`
        let remote_addr = self.remote_addr;
        let sock = 
            self.sock
                .as_mut()
                .ok_or_else(|| 
                    internal_error_map!(
                        CDnsErrorType::InternalError, 
                        "socket to '{}' is not connected", 
                        remote_addr
                    )
                )?;

        sock.write_all(sndbuf)
            .await
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        return Ok(sndbuf.len());
    }

    /// Reads whatever is available from the stream into `rcvbuf`, waiting at
    /// most `self.timeout`.
    ///
    /// As with tokio's `AsyncReadExt::read`, `Ok(0)` means end of stream (or
    /// an empty `rcvbuf`). If the timeout elapses, a `RequestTimeout` error is
    /// returned. Returns an `InternalError` (instead of panicking) when called
    /// before [SocketTap::connect].
    async 
    fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize> 
    {
        /// Reads from `sock`, retrying transient errors.
        async 
        fn sub_recv(sock: &mut TcpStream, rcvbuf: &mut [u8]) -> CDnsResult<usize> 
        {
            loop
            {
                match sock.read(rcvbuf).await
                {
                    Ok(n) => 
                    {
                        return Ok(n);
                    },
                    Err(ref e) 
                        if matches!(e.kind(), ErrorKind::WouldBlock | ErrorKind::Interrupted) =>
                    {
                        continue;
                    },
                    Err(e) =>
                    {
                        internal_error!(CDnsErrorType::IoError, "{}", e); 
                    }
                } // match
            } // loop
        }

        // copied out so the error closure does not borrow `self`
        let remote_addr = self.remote_addr;
        let sock = 
            self.sock
                .as_mut()
                .ok_or_else(|| 
                    internal_error_map!(
                        CDnsErrorType::InternalError, 
                        "socket to '{}' is not connected", 
                        remote_addr
                    )
                )?;

        // wait for timeout
        match timeout(self.timeout, sub_recv(sock, rcvbuf)).await
        {
            Ok(r) => return r,
            Err(e) => internal_error!(CDnsErrorType::RequestTimeout, "{}", e)
        }
    }
}

// Placeholder for inherent methods shared by every `NetworkTap<T>`;
// it is currently empty.
impl<T> NetworkTap<T> {}


/// Creates a new UDP [SocketTap] aimed at the nameserver at
/// `resolver_ip`:`resolver_port`.
///
/// No [UdpSocket] is created yet: the returned tap is not connected and
/// [SocketTap::connect] must be called before sending or receiving.
/// 
/// # Arguments
/// 
/// * `resolver_ip` - a ref to [IpAddr] which holds the host address of the nameserver.
/// 
/// * `resolver_port` - the port the nameserver listens on.
/// 
/// * `bind_addr` - a local address to bind the socket to.
///
/// * `timeout` - the timeout stored in the tap (used e.g. when waiting for a response).
/// 
/// # Returns
/// 
/// * [CDnsResult] - Ok with inner type [Box] dyn [SocketTap]
/// 
/// * [CDnsResult] - Err with error description (this function currently never fails)
pub 
fn new_udp(
    resolver_ip: &IpAddr, 
    resolver_port: u16, 
    bind_addr: &SocketAddr,
    timeout: Duration
) -> CDnsResult<Box<NetworkTapType>>
{
    // IpAddr and SocketAddr are `Copy`, so plain dereferences suffice
    let ret = 
        NetworkTap::<UdpSocket>
        { 
            sock: None, 
            remote_addr: SocketAddr::new(*resolver_ip, resolver_port), 
            bind_addr: *bind_addr,
            timeout,
        };
    
    return Ok( Box::new(ret) );
}

/// Creates a new TCP [SocketTap] aimed at the nameserver at
/// `resolver_ip`:`resolver_port`.
///
/// No [TcpStream] is created yet: the returned tap is not connected and
/// [SocketTap::connect] must be called before sending or receiving.
/// 
/// # Arguments
/// 
/// * `resolver_ip` - a ref to [IpAddr] which holds the host address of the nameserver.
/// 
/// * `resolver_port` - the port the nameserver listens on.
/// 
/// * `bind_addr` - a local address to bind the socket to.
///
/// * `timeout` - the timeout stored in the tap (used e.g. when waiting for a response).
/// 
/// # Returns
/// 
/// * [CDnsResult] - Ok with inner type [Box] dyn [SocketTap]
/// 
/// * [CDnsResult] - Err with error description (this function currently never fails)
pub 
fn new_tcp(
    resolver_ip: &IpAddr, 
    resolver_port: u16, 
    bind_addr: &SocketAddr,
    timeout: Duration
) -> CDnsResult<Box<NetworkTapType>>
{
    // IpAddr and SocketAddr are `Copy`, so plain dereferences suffice
    let ret = 
        NetworkTap::<TcpStream>
        { 
            sock: None, 
            remote_addr: SocketAddr::new(*resolver_ip, resolver_port),
            bind_addr: *bind_addr,
            timeout,
        };

    return Ok( Box::new(ret) );
}



/// Checks that [new_udp] builds an unconnected UDP tap aimed at the
/// requested nameserver.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_struct()
{
    use crate::a_sync::common::IPV4_BIND_ALL;
    
    let ip0: IpAddr = "127.0.0.1".parse().unwrap();
    let bind = SocketAddr::from((IPV4_BIND_ALL, 0));

    let tap = 
        match new_udp(&ip0, 53, &bind, Duration::from_secs(5))
        {
            Ok(tap) => tap,
            Err(e) => panic!("new_udp failed: {}", e),
        };

    assert!(!tap.is_tcp());
    assert_eq!(tap.get_remote_addr(), &SocketAddr::from((ip0, 53)));
    // no socket exists until connect() is called
    assert!(!tap.is_conncected().await);
}
