/*-
 * cdns-rs - a simple sync/async DNS query library
 * Copyright (C) 2020  Aleksandr Morozov, RELKOM s.r.o
 * Copyright (C) 2021-2022  Aleksandr Morozov
 * 
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 *  file, You can obtain one at https://mozilla.org/MPL/2.0/.
 */

//! This file contains the networking code.

use std::io::{ErrorKind, Write};
use std::io::prelude::*;
use std::net::{IpAddr, SocketAddr};
use std::net::UdpSocket;
use std::os::unix::prelude::{AsRawFd, RawFd};
use std::time::Duration;
use core::fmt::Debug;

use std::net::TcpStream;

use nix::poll::{PollFd, PollFlags};

use socket2::{Socket, Domain, Type, Protocol, SockAddr};

use crate::{internal_error, internal_error_map};

use crate::error::*;


/// A common interface to access the realizations for both [TcpStream] and
/// [UdpSocket].
///
/// Both realizations in this file keep the channel inside a [NetworkTap] and
/// create it lazily in [SocketTap::connect].
pub trait SocketTap
{
    /// Connects to the remote host.  
    /// If `timeout` is not set, the socket is initialized in non blocking mode
    ///  and [PollFd] is created to use with `poll(2)`.  
    /// If `timeout` is set, then the socket is initialized in blocking mode with
    ///  that value used as the read timeout. The [PollFd] is not generated!
    ///
    /// Calling it on an instance which is already connected does nothing.
    fn connect(&mut self, timeout: Option<Duration>) -> CDnsResult<()>;

    /// Tells if current instance is [TcpStream] if true, or [UdpSocket] if false
    fn is_tcp(&self) -> bool;

    /// Tells if socket/stream is connected to remote host, i.e. if
    /// [SocketTap::connect] has already succeeded.
    fn is_conncected(&self) -> bool;

    /// Returns the [PollFd].
    ///
    /// # Panics
    ///
    /// Will panic! if instance was initialized in blocking mode (or was not
    /// connected yet), because no [PollFd] exists in that case.
    fn get_pollfd(&self) -> PollFd;

    /// Returns the remote host Ip and port.
    fn get_remote_addr(&self) -> &SocketAddr;

    /// Sends data over wire.  
    /// Returns the number of bytes which were handed over to the socket.
    fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize> ;

    /// Receives data transmitted from remote host.
    /// In nonblocking mode it should be called only after the event was polled.
    /// In blocking mode it will block until received or timeout.
    /// Returns the number of bytes stored in `rcvbuf`.
    fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize>;
}


/// Debug output for a type-erased [SocketTap].
///
/// Only the information reachable through the trait itself is printed: the
/// transport kind, the connection state and the remote address.
/// The value must NOT be formatted with `{:?}` on `self` here, because that
/// would call this very implementation again and recurse until the stack
/// overflows.
impl Debug for dyn SocketTap 
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result 
    {
        f.debug_struct("SocketTap")
            .field("is_tcp", &self.is_tcp())
            .field("is_connected", &self.is_conncected())
            .field("remote_addr", self.get_remote_addr())
            .finish()
    }
}

/// An instance of the socket/stream.
///
/// `T` is the type of the underlying channel. In this file [SocketTap] is
/// implemented for `NetworkTap<UdpSocket>` and `NetworkTap<TcpStream>`.
#[derive(Debug)]
pub struct NetworkTap<T>
{
    /// Channel. It is `None` until [SocketTap::connect] succeeds.
    sock: Option<T>, 
    /// Nameserver address and port
    remote_addr: SocketAddr, 
    /// Local address to use to bind socket
    bind_addr: SocketAddr,
    /// Timeout which was passed to [SocketTap::connect]. It is applied as the
    /// read timeout: `Some` means blocking mode, `None` means non blocking mode.
    timeout: Option<Duration>,
    /// [PollFd] from sock. It is only created in non blocking mode.
    pfd: Option<PollFd>
}

/// [SocketTap] realization for the datagram channel ([UdpSocket]).
impl SocketTap for NetworkTap<UdpSocket>
{
    /// Binds the socket to `bind_addr` and "connects" it to the nameserver.
    ///
    /// * `timeout` is `None` - the socket is switched to non blocking mode and
    ///  a [PollFd] is created for it.
    ///
    /// * `timeout` is `Some` - the socket stays in blocking mode and the value
    ///  is used as the read timeout.
    ///
    /// Does nothing if the socket was already created. The instance state is
    /// modified only after every step has succeeded.
    fn connect(&mut self, timeout: Option<Duration>) -> CDnsResult<()>
    {
        if self.sock.is_some()
        {
            // already connected, ignore
            return Ok(());
        }

        let socket = 
            UdpSocket::bind(self.bind_addr)
                .map_err(|e| internal_error_map!(CDnsErrorType::InternalError, "{}", e))?;

        // set mode: no timeout means non blocking
        socket.set_nonblocking(timeout.is_none())
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        // set timeout
        socket.set_read_timeout(timeout)
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        // restricts the socket to exchange datagrams with the nameserver only
        socket.connect(&self.remote_addr)
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        if timeout.is_none()
        {
            // Creating PollFd structure. It is needed in async Poll mode.
            self.pfd = Some(create_poll_fd(socket.as_raw_fd()));
        }

        self.timeout = timeout;
        self.sock = Some(socket);

        return Ok(());
    }

    /// Always `false`, this is a [UdpSocket].
    fn is_tcp(&self) -> bool 
    {
        return false;
    }

    /// Tells if the socket was created by [SocketTap::connect].
    fn is_conncected(&self) -> bool 
    {
        return self.sock.is_some();
    }

    /// Returns a copy of the [PollFd].
    ///
    /// # Panics
    ///
    /// Panics if the instance was initialized in blocking mode or was not
    /// connected, because the [PollFd] does not exist then.
    fn get_pollfd(&self) -> PollFd 
    {
        return 
            self.pfd
                .clone()
                .expect("PollFd is available only in non blocking mode after connect()");
    }

    /// Returns the nameserver address and port.
    fn get_remote_addr(&self) -> &SocketAddr
    {
        return &self.remote_addr;
    }

    /// Sends the datagram to the connected nameserver.
    ///
    /// Returns an error (instead of panicking) if the socket is not connected.
    fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize> 
    {
        let sock = 
            match self.sock.as_ref()
            {
                Some(s) => s,
                None => 
                {
                    internal_error!(CDnsErrorType::IoError, "socket is not connected to: '{}'", self.remote_addr);
                }
            };

        return 
            sock
                .send(sndbuf)
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e));
    }

    /// Receives one datagram from the nameserver.
    ///
    /// * blocking mode - `WouldBlock` means that the read timeout has expired
    ///  and a `RequestTimeout` error is returned.
    ///
    /// * non blocking mode - `WouldBlock` is retried a few times before giving
    ///  up with an `IoError`.
    ///
    /// An interrupted call is always restarted. Returns an error (instead of
    /// panicking) if the socket is not connected.
    fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize> 
    {
        let remote_addr = self.remote_addr;
        let blocking = self.timeout.is_some();

        let sock = 
            match self.sock.as_ref()
            {
                Some(s) => s,
                None => 
                {
                    internal_error!(CDnsErrorType::IoError, "socket is not connected to: '{}'", remote_addr);
                }
            };

        // amount of attempts left in non blocking mode
        let mut retry: u64 = 5;

        loop
        {
            match sock.recv_from(rcvbuf)
            {
                Ok((rcv_len, rcv_src)) =>
                {
                    // this should not fail because socket is "connected"
                    if rcv_src != remote_addr
                    {
                        // first the actual (unknown) sender, then the expected one
                        internal_error!(
                            CDnsErrorType::DnsResponse, 
                            "received answer from unknown host: '{}' exp: '{}'", 
                            rcv_src,
                            remote_addr
                        );
                    }

                    return Ok(rcv_len);
                },
                Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
                {
                    if blocking
                    {
                        // blocking mode: the read timeout has expired
                        internal_error!(CDnsErrorType::RequestTimeout, "request timeout from: '{}'", remote_addr); 
                    }

                    // non blocking mode
                    if retry == 0
                    {
                        internal_error!(CDnsErrorType::IoError, "can not receive from: '{}'", remote_addr); 
                    }

                    retry -= 1;
                },
                Err(ref e) if e.kind() == ErrorKind::Interrupted =>
                {
                    // interrupted by signal, just restart
                    continue;
                },
                Err(e) =>
                {
                    internal_error!(CDnsErrorType::IoError, "{}", e); 
                }
            } // match
        } // loop
    }
}

/// [SocketTap] realization for the stream channel ([TcpStream]).
impl SocketTap for NetworkTap<TcpStream>
{
    /// Creates the socket, binds it to `bind_addr` and connects it to the
    /// nameserver.
    ///
    /// * `timeout` is `None` - the connection is established with a plain
    ///  `connect`, then the stream is switched to non blocking mode and a
    ///  [PollFd] is created for it.
    ///
    /// * `timeout` is `Some` - the connection attempt itself is limited by
    ///  this value, the stream stays in blocking mode and the value is used
    ///  as the read timeout.
    ///
    /// Does nothing if the stream was already created. The instance state is
    /// modified only after every step has succeeded.
    fn connect(&mut self, timeout: Option<Duration>) -> CDnsResult<()> 
    {
        if self.sock.is_some()
        {
            // already connected, ignore
            return Ok(());
        }

        // create socket
        let socket = 
            Socket::new(Domain::for_address(self.remote_addr), Type::STREAM, Some(Protocol::TCP))
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        // bind socket
        socket.bind(&SockAddr::from(self.bind_addr))
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        // connect. In blocking mode the attempt must respect the timeout,
        // otherwise an unreachable nameserver blocks the caller until the
        // operating system gives up.
        let remote = SockAddr::from(self.remote_addr);
        let connected = 
            match timeout
            {
                Some(limit) => socket.connect_timeout(&remote, limit),
                None => socket.connect(&remote),
            };

        connected
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        // convert to TcpStream
        let socket: TcpStream = socket.into();

        // set mode: no timeout means non blocking
        socket.set_nonblocking(timeout.is_none())
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        // set timeout
        socket.set_read_timeout(timeout)
            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

        if timeout.is_none()
        {
            // Creating PollFd structure. It is needed in async Poll mode.
            self.pfd = Some(create_poll_fd(socket.as_raw_fd()));
        }

        self.timeout = timeout;
        self.sock = Some(socket);

        return Ok(());
    }

    /// Always `true`, this is a [TcpStream].
    fn is_tcp(&self) -> bool 
    {
        return true;
    }

    /// Tells if the stream was created by [SocketTap::connect].
    fn is_conncected(&self) -> bool 
    {
        return self.sock.is_some();
    }

    /// Returns a copy of the [PollFd].
    ///
    /// # Panics
    ///
    /// Panics if the instance was initialized in blocking mode or was not
    /// connected, because the [PollFd] does not exist then.
    fn get_pollfd(&self) -> PollFd 
    {
        return 
            self.pfd
                .clone()
                .expect("PollFd is available only in non blocking mode after connect()");
    }

    /// Returns the nameserver address and port.
    fn get_remote_addr(&self) -> &SocketAddr
    {
        return &self.remote_addr;
    }

    /// Writes data to the stream. The returned amount can be less than the
    /// length of `sndbuf`.
    ///
    /// Returns an error (instead of panicking) if the stream is not connected.
    fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize>  
    {
        let stream = 
            match self.sock.as_mut()
            {
                Some(s) => s,
                None => 
                {
                    internal_error!(CDnsErrorType::IoError, "stream is not connected to: '{}'", self.remote_addr);
                }
            };

        return 
            stream
                .write(sndbuf)
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e));
    }

    /// Reads data from the stream.
    ///
    /// * blocking mode - `WouldBlock` means that the read timeout has expired
    ///  and a `RequestTimeout` error is returned.
    ///
    /// * non blocking mode - `WouldBlock` is retried a few times before giving
    ///  up with an `IoError`.
    ///
    /// An interrupted call is always restarted. Returns an error (instead of
    /// panicking) if the stream is not connected.
    fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize> 
    {
        let remote_addr = self.remote_addr;
        let blocking = self.timeout.is_some();

        let stream = 
            match self.sock.as_mut()
            {
                Some(s) => s,
                None => 
                {
                    internal_error!(CDnsErrorType::IoError, "stream is not connected to: '{}'", remote_addr);
                }
            };

        // amount of attempts left in non blocking mode
        let mut retry: u64 = 5;

        loop
        {
            match stream.read(rcvbuf)
            {
                Ok(n) => 
                {
                    return Ok(n);
                },
                Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
                {
                    if blocking
                    {
                        // blocking mode: the read timeout has expired
                        internal_error!(CDnsErrorType::RequestTimeout, "request timeout from: '{}'", remote_addr); 
                    }

                    // non blocking mode
                    if retry == 0
                    {
                        internal_error!(CDnsErrorType::IoError, "can not receive from: '{}'", remote_addr); 
                    }

                    retry -= 1;
                },
                Err(ref e) if e.kind() == ErrorKind::Interrupted =>
                {
                    // interrupted by signal, just restart
                    continue;
                },
                Err(e) =>
                {
                    internal_error!(CDnsErrorType::IoError, "{}", e); 
                }
            } // match
        } // loop
    }
}

// NOTE(review): this inherent impl is empty in the visible source, no methods
// common to every `NetworkTap<T>` are defined here.
impl<T> NetworkTap<T>
{
    
}


/// Prepares a new, not yet connected, [UdpSocket] based instance.
/// Nothing is opened here, so [SocketTap::connect] should be called before
/// sending or receiving.
/// 
/// # Arguments
/// 
/// * `resolver_ip` - a ref to [IpAddr] which holds host address of the nameserver.
/// 
/// * `resolver_port` - a port number the nameserver is bound to
/// 
/// * `bind_addr` - a local address to bind the socket to
/// 
/// # Returns
/// 
/// * [CDnsResult] - Ok with inner type [Box] dyn [SocketTap]
/// 
/// * [CDnsResult] - Err with error description
pub 
fn new_udp(
    resolver_ip: &IpAddr, 
    resolver_port: u16, 
    bind_addr: &SocketAddr
) -> CDnsResult<Box<dyn SocketTap>>
{
    let tap = 
        NetworkTap::<UdpSocket>
        {
            sock: None,
            // nameserver address and port
            remote_addr: SocketAddr::new(*resolver_ip, resolver_port),
            bind_addr: *bind_addr,
            timeout: None,
            pfd: None,
        };

    Ok(Box::new(tap))
}

/// Prepares a new, not yet connected, [TcpStream] based instance.
/// Nothing is opened here, so [SocketTap::connect] should be called before
/// sending or receiving.
/// 
/// # Arguments
/// 
/// * `resolver_ip` - a ref to [IpAddr] which holds host address of the nameserver.
/// 
/// * `resolver_port` - a port number the nameserver is bound to
/// 
/// * `bind_addr` - a local address to bind the socket to
/// 
/// # Returns
/// 
/// * [CDnsResult] - Ok with inner type [Box] dyn [SocketTap]
/// 
/// * [CDnsResult] - Err with error description
pub 
fn new_tcp(
    resolver_ip: &IpAddr, 
    resolver_port: u16, 
    bind_addr: &SocketAddr
) -> CDnsResult<Box<dyn SocketTap>>
{
    let tap = 
        NetworkTap::<TcpStream>
        {
            sock: None,
            // nameserver address and port
            remote_addr: SocketAddr::new(*resolver_ip, resolver_port),
            bind_addr: *bind_addr,
            timeout: None,
            pfd: None,
        };

    Ok(Box::new(tap))
}

/// Wraps the [RawFd] into a [PollFd] which requests the `POLLIN` event.
fn create_poll_fd(fd: RawFd) -> PollFd
{
    let events = PollFlags::POLLIN;

    PollFd::new(fd, events)
}


/// Checks that an unconnected UDP instance can be constructed.
#[test]
fn test_struct()
{
    use super::common::IPV4_BIND_ALL;

    // nameserver on the loopback, local side bound to any address and any port
    let nameserver = IpAddr::from([127, 0, 0, 1]);
    let local = SocketAddr::from((IPV4_BIND_ALL, 0));

    let res = new_udp(&nameserver, 53, &local);

    assert_eq!(res.is_ok(), true, "{}", res.err().unwrap());

    let _tap = res.unwrap();
}
