//! DHT node

use crate::{
    directory::Directory,
    id::{Id, InfoHash},
    message::{
        ErrorCode, ErrorMessage, ExpectedResponse, Message, Port, QueryMessage, Response,
        ResponseMessage,
    },
    node_info::NodeInfo,
    table::Table,
    timer::Timer,
    token::Token,
    transport::{Transport, TransportExt},
    txid::TxId,
    Error, Result,
};
use async_stream::try_stream;
use futures::{
    channel::{
        mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
        oneshot::{channel, Receiver, Sender},
    },
    pending, poll, ready,
    stream::{FuturesUnordered, Stream, StreamExt},
};
use pin_project::pin_project;
use rand::random;
use serde_bencode::{from_bytes as bdecode, to_bytes as bencode};
use std::{
    borrow::Borrow,
    cmp::Ordering,
    collections::{btree_map::Entry, BTreeMap, BTreeSet},
    future::Future,
    net::SocketAddrV4,
    pin::Pin,
    result::Result as StdResult,
    task::Poll::Ready,
    task::{Context, Poll},
    time::Duration,
};

/// A DHT node
///
/// Owns the transport, the directory of known nodes and the table of announced
/// peers, and multiplexes all incoming and outgoing traffic. Nothing happens until
/// the node is driven with [`Node::run`].
pub struct Node {
    /// Datagram transport used for every send and receive.
    transport: Pin<Box<dyn Transport + Unpin>>,

    /// Known DHT nodes; also the source of our own id (`directory.id()`).
    directory: Directory,
    /// Peers that announced themselves to us, looked up by info-hash.
    table: Table,

    /// When expired, unanswered queries are retried or timed out.
    purge_queries_timer: Timer,
    /// When expired, `Directory::refresh` is called.
    refresh_buckets_timer: Timer,
    /// When expired, `Table::expire` is called.
    expire_peers_timer: Timer,

    /// The node's own handle onto `query_queue`; copies are also handed to the
    /// directory and the search handle in `Node::new`.
    query_handle: QueryHandle,
    /// Outgoing queries submitted through any `QueryHandle`, transmitted by `do_send`.
    query_queue: UnboundedReceiver<QuerySlot>,

    /// Closest-node requests from `SearchHandle::search`, answered in `do_maintenance`.
    closest_queue: UnboundedReceiver<(Id, Sender<Vec<NodeInfo>>)>,

    /// Replies to incoming queries, queued by `do_receive` and flushed by `do_send`.
    pending_replies: Vec<(Message, SocketAddrV4)>,
    /// Sent queries awaiting an answer, looked up by transaction id.
    pending_queries: BTreeSet<QuerySlot>,
    /// Tokens handed out in `get_peers` replies and not yet redeemed.
    /// NOTE(review): entries are removed only when an `announce_peer` redeems them,
    /// so this set grows with every `get_peers` that is never followed up — consider
    /// expiring tokens.
    pending_tokens: BTreeSet<Token>,

    /// Pool of pre-generated random tokens consumed by `get_token`.
    available_tokens: BTreeSet<Token>,
}

impl Node {
    /// Creates a new node over the given transport
    pub fn new<T>(my_id: Id, transport: T) -> (Self, SearchHandle)
    where
        T: Transport + Unpin + 'static,
    {
        let (query_sender, query_queue) = unbounded();
        let (closest_sender, closest_queue) = unbounded();
        let query_handle = QueryHandle { query_sender };
        (
            Self {
                transport: Box::pin(transport),
                purge_queries_timer: Timer::new(Self::PURGE_QUERIES_INTERVAL),
                refresh_buckets_timer: Timer::new(Self::REFRESH_BUCKETS_INTERVAL),
                expire_peers_timer: Timer::new(Self::EXPIRE_PEERS_INTERVAL),
                directory: Directory::new(my_id, query_handle.clone()),
                table: Table::new(),
                query_handle: query_handle.clone(),
                query_queue,
                closest_queue,
                pending_replies: Vec::new(),
                pending_queries: BTreeSet::new(),
                pending_tokens: BTreeSet::new(),
                available_tokens: BTreeSet::new(),
            },
            SearchHandle {
                my_id,
                query_handle,
                closest_sender,
            },
        )
    }

    #[cfg(test)]
    fn get_query_handle(&self) -> QueryHandle {
        self.query_handle.clone()
    }

    /// Runs the node
    pub async fn run(mut self) -> Result<()> {
        self.preallocate_tokens();
        loop {
            self.do_receive().await?;
            self.do_send().await?;
            self.do_maintenance().await?;
            pending!();
        }
    }

    async fn do_receive(&mut self) -> Result<()> {
        let my_id = self.directory.id();
        let mut buf = [0u8; Self::MAX_DATAGRAM_LEN];
        while let Ready(res) = poll!(self.transport.receive(&mut buf)) {
            let (n, address) = res?;
            let buf = &buf[..n];
            if let Ok(message) = bdecode(buf) {
                match message {
                    Message::Query { txid, message } => match message {
                        QueryMessage::Ping { id } => {
                            self.found_node(id, address);
                            let message = Message::Response {
                                txid,
                                message: Response::Pong { id: my_id }.into(),
                            };
                            self.pending_replies.push((message, address));
                        }
                        QueryMessage::FindNode { id, target } => {
                            self.found_node(id, address);
                            let nodes = self.directory.closest_nodes(target);
                            let message = Message::Response {
                                txid,
                                message: Response::ClosestNodes { id: my_id, nodes }.into(),
                            };
                            self.pending_replies.push((message, address));
                        }
                        QueryMessage::GetPeers { id, info_hash } => {
                            self.found_node(id, address);
                            let token = self.get_token();
                            let message = match self.table.get_addresses(&info_hash) {
                                Some(peers) => Message::Response {
                                    txid,
                                    message: Response::KnownPeers {
                                        id: my_id,
                                        token: token.clone(),
                                        peers,
                                    }
                                    .into(),
                                },
                                None => {
                                    let nodes = self.directory.closest_nodes(info_hash);
                                    Message::Response {
                                        txid,
                                        message: Response::NoKnownPeers {
                                            id: my_id,
                                            token: token.clone(),
                                            nodes,
                                        }
                                        .into(),
                                    }
                                }
                            };
                            self.pending_replies.push((message, address));
                            self.pending_tokens.insert(token);
                        }
                        QueryMessage::AnnouncePeer {
                            id,
                            token,
                            info_hash,
                            port,
                        } => {
                            self.found_node(id, address);
                            let message = match self.pending_tokens.take(&token) {
                                Some(_) => {
                                    let port = port.unwrap_or(address.port());
                                    let address = SocketAddrV4::new(*address.ip(), port);
                                    self.table.insert(info_hash, address);
                                    Message::Response {
                                        txid,
                                        message: Response::Announced { id: my_id }.into(),
                                    }
                                }
                                None => Message::Error {
                                    txid,
                                    message: ErrorMessage::new(
                                        ErrorCode::ProtocolError,
                                        "bad token",
                                    ),
                                },
                            };
                            self.pending_replies.push((message, address));
                        }
                    },
                    Message::Response { txid, message } => {
                        if let Some(slot) = self.pending_queries.take(&txid) {
                            self.found_node(message.id(), address);
                            slot.answer(message);
                        }
                    }
                    Message::Error { txid, message } => {
                        if let Some(slot) = self.pending_queries.take(&txid) {
                            slot.fail(message);
                        }
                    }
                }
            }
        }
        Ok(())
    }

    async fn do_send(&mut self) -> Result<()> {
        while let Ok(Some(mut slot)) = self.query_queue.try_next() {
            slot.message.set_my_id(self.directory.id());
            let buf = bencode(&slot.message).expect("invalid message");
            self.transport.send(&buf, slot.address).await?;
            self.pending_queries.insert(slot);
        }
        for (mut message, address) in self.pending_replies.drain(..) {
            message.set_my_id(self.directory.id());
            let buf = bencode(&message).expect("invalid message");
            self.transport.send(&buf, address).await?;
        }
        Ok(())
    }

    async fn do_maintenance(&mut self) -> Result<()> {
        while let Ok(Some((id, reply))) = self.closest_queue.try_next() {
            let closest = self.directory.closest_nodes(id);
            let _ = reply.send(closest);
        }
        if self.purge_queries_timer.is_expired() {
            let mut retried = vec![];
            let expired = self.pending_queries.drain_filter(QuerySlot::is_expired);
            for mut slot in expired {
                if slot.exhausted_retries() {
                    slot.timeout();
                } else {
                    slot.mark_retry();
                    retried.push(slot);
                }
            }
            self.pending_queries.extend(retried);
            self.purge_queries_timer.restart();
        }
        if self.refresh_buckets_timer.is_expired() {
            self.directory.refresh();
            self.refresh_buckets_timer.restart();
        }
        if self.expire_peers_timer.is_expired() {
            self.table.expire();
            self.expire_peers_timer.restart();
        }
        Ok(())
    }

    fn preallocate_tokens(&mut self) {
        while self.available_tokens.len() < Self::TOKEN_BATCH_SIZE {
            self.available_tokens.insert(random());
        }
    }

    fn get_token(&mut self) -> Token {
        loop {
            match self.available_tokens.pop_first() {
                Some(token) => {
                    break token;
                }
                None => self.preallocate_tokens(),
            }
        }
    }

    fn found_node(&mut self, id: Id, address: SocketAddrV4) {
        self.directory.add_node(NodeInfo::new(id, address));
    }

    const MAX_DATAGRAM_LEN: usize = 1024;
    const TOKEN_BATCH_SIZE: usize = 64;
    const PURGE_QUERIES_INTERVAL: Duration = Duration::from_secs(10);
    const REFRESH_BUCKETS_INTERVAL: Duration = Duration::from_secs(60);
    const EXPIRE_PEERS_INTERVAL: Duration = Duration::from_secs(60);
}

/// A handle for performing searches
///
/// Obtained from [`Node::new`]. Its queries and closest-node requests are served
/// by that node, so searches only make progress while the node is running.
#[derive(Clone)]
pub struct SearchHandle {
    /// Our own node id; nodes advertising this id are never queried by a search.
    my_id: Id,
    /// Used to issue `get_peers` and `announce_peer` queries.
    query_handle: QueryHandle,
    /// Asks the node for the closest nodes it knows to a given id.
    closest_sender: UnboundedSender<(Id, Sender<Vec<NodeInfo>>)>,
}

/// Announcement setting
///
/// Controls whether [`SearchHandle::search`] announces us as a peer to the nodes
/// that answered the search, and with which port.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Announce {
    /// Do not announce
    No,
    /// Announce on the same port as DHT
    SamePort,
    /// Announce on a specific port
    OnPort(u16),
}

impl SearchHandle {
    /// Performs a search for peers downloading a torrent, optionally announcing yourself
    ///
    /// The search starts from the nodes our own node knows closest to `info_hash`.
    /// It sends `get_peers` to every newly learned node. It stops once the first
    /// `Directory::BUCKET_SIZE` tracked nodes have all replied, or when no queries
    /// remain outstanding. Each distinct peer address is yielded once, as soon as it
    /// is learned.
    ///
    /// Afterwards, unless `announce` is [`Announce::No`], an `announce_peer` query is
    /// fired at every node that replied with a token. The stream does not wait for
    /// those replies.
    ///
    /// # Panics
    /// Panics if the node drops the closest-nodes request without answering it.
    pub fn search(
        &mut self,
        info_hash: InfoHash,
        announce: Announce,
    ) -> impl Stream<Item = Result<SocketAddrV4>> + '_ {
        try_stream! {
            // Maps the announce setting onto the wire-level port; `None` = don't announce.
            let announce_port = || match announce {
                Announce::No => None,
                Announce::SamePort => Some(Port::Implicit),
                Announce::OnPort(port) => Some(Port::Explicit(port)),
            };

            // Seed the search with the closest nodes our own node knows about.
            let (closest_tx, closest_rx) = channel();
            self.closest_sender.unbounded_send((info_hash, closest_tx))?;
            let seeds = closest_rx.await.expect("oneshot cancelled");

            let mut pending: BTreeMap<NodeInfo, SearchStatus> = BTreeMap::new();
            for seed in seeds {
                pending.insert(seed, SearchStatus::Waiting);
            }
            let mut queries = FuturesUnordered::new();
            for node in pending.keys() {
                queries.push(self.query_handle.get_peers(info_hash, node.address())?);
            }

            let mut found_peers = BTreeSet::new();
            while !fully_resolved(&pending) {
                let outcome = match queries.next().await {
                    Some(outcome) => outcome,
                    None => break,
                };
                match outcome {
                    Ok(Response::KnownPeers { id, peers, token }) => {
                        resolve_node(&mut pending, id, token);
                        for peer in peers {
                            let is_new = found_peers.insert(peer);
                            if is_new {
                                yield peer;
                            }
                        }
                    }
                    Ok(Response::NoKnownPeers { id, nodes, token }) => {
                        resolve_node(&mut pending, id, token);
                        for node in nodes {
                            if node.id() == self.my_id {
                                continue;
                            }
                            let address = node.address();
                            if !insert_search_node(&mut pending, node) {
                                continue;
                            }
                            if let Ok(query) = self.query_handle.get_peers(info_hash, address) {
                                queries.push(query);
                            }
                        }
                    }
                    // Other response kinds and failed queries are ignored.
                    _ => {}
                }
            }

            for (node, status) in pending {
                let token = match status {
                    SearchStatus::Resolved(token) => token,
                    SearchStatus::Waiting => continue,
                };
                if let Some(port) = announce_port() {
                    // Fire and forget: the reply handle is dropped immediately.
                    let _ = self
                        .query_handle
                        .announce_peer(token, info_hash, port, node.address());
                }
            }
        }
    }
}

/// Progress of a single node taking part in a peer search.
#[derive(Debug, Clone, Eq, PartialEq)]
enum SearchStatus {
    /// The node is tracked but has not (yet) replied to our `get_peers` query.
    Waiting,
    /// The node replied; its token is kept for a later `announce_peer`.
    Resolved(Token),
}

/// Adds `node` to the nodes tracked by a search, keeping at most
/// `2 * Directory::BUCKET_SIZE` entries by evicting the greatest keys.
///
/// Returns `true` only when `node` was newly added and survived the trim, i.e.
/// when the caller should send it a `get_peers` query. A node that is already
/// tracked returns `false`, whether it is still waiting for a reply or already
/// resolved, so it is never queried twice.
///
/// NOTE(review): "greatest" follows `NodeInfo`'s `Ord`. `resolve_node` looks this
/// map up by `Id`, so that ordering appears to be by raw id rather than XOR
/// distance to the info-hash. Confirm this matches the intended "closest" pruning.
fn insert_search_node(nodes: &mut BTreeMap<NodeInfo, SearchStatus>, node: NodeInfo) -> bool {
    let capacity = 2 * Directory::BUCKET_SIZE;
    // Keep the id so we can tell afterwards whether trimming evicted this node.
    let id = node.id();
    match nodes.entry(node) {
        // Already queried (and possibly answered): never query it again.
        Entry::Occupied(_) => return false,
        Entry::Vacant(slot) => {
            slot.insert(SearchStatus::Waiting);
        }
    }
    while nodes.len() > capacity {
        nodes.pop_last();
    }
    // If the new node itself was the one evicted, it must not be queried.
    nodes.contains_key(&id)
}

/// Marks the tracked node with id `id` as having replied, and keeps the `token`
/// it handed out for a later `announce_peer`.
///
/// Replies from nodes that are not (or no longer) tracked are ignored.
fn resolve_node(nodes: &mut BTreeMap<NodeInfo, SearchStatus>, id: Id, token: Token) {
    if let Some(status) = nodes.get_mut(&id) {
        *status = SearchStatus::Resolved(token);
    }
}

/// Returns `true` when none of the first `Directory::BUCKET_SIZE` tracked nodes,
/// in map order, is still waiting for a reply. An empty map counts as resolved.
fn fully_resolved(nodes: &BTreeMap<NodeInfo, SearchStatus>) -> bool {
    let still_waiting = nodes
        .values()
        .take(Directory::BUCKET_SIZE)
        .any(|status| matches!(status, SearchStatus::Waiting));
    !still_waiting
}

/// A handle for performing queries
///
/// Clones share the same queue into the owning node, which transmits the queries
/// from its run loop.
#[derive(Clone)]
pub struct QueryHandle {
    /// Sending half of the node's outgoing-query queue.
    query_sender: UnboundedSender<QuerySlot>,
}

impl QueryHandle {
    /// Sends a `ping` query to `address`.
    ///
    /// # Errors
    /// Fails if the node's query queue no longer accepts queries.
    pub fn ping(&mut self, address: SocketAddrV4) -> Result<ReplyHandle> {
        let message = QueryMessage::ping();
        self.query(message, address)
    }

    /// Sends a `find_node` query for `target` to `address`.
    ///
    /// # Errors
    /// Fails if the node's query queue no longer accepts queries.
    pub fn find_node(&mut self, target: Id, address: SocketAddrV4) -> Result<ReplyHandle> {
        let message = QueryMessage::find_node(target);
        self.query(message, address)
    }

    /// Sends a `get_peers` query for `info_hash` to `address`.
    ///
    /// # Errors
    /// Fails if the node's query queue no longer accepts queries.
    pub fn get_peers(&mut self, info_hash: Id, address: SocketAddrV4) -> Result<ReplyHandle> {
        let message = QueryMessage::get_peers(info_hash);
        self.query(message, address)
    }

    /// Sends an `announce_peer` query to `address`, presenting `token` obtained
    /// from an earlier `get_peers` reply by that node.
    ///
    /// # Errors
    /// Fails if the node's query queue no longer accepts queries.
    pub fn announce_peer(
        &mut self,
        token: Token,
        info_hash: Id,
        port: Port,
        address: SocketAddrV4,
    ) -> Result<ReplyHandle> {
        let message = QueryMessage::announce_peer(token, info_hash, port);
        self.query(message, address)
    }

    /// Queues `message` for `address` and returns a handle resolving to its reply.
    fn query(&mut self, message: QueryMessage, address: SocketAddrV4) -> Result<ReplyHandle> {
        let (slot, reply_receiver) = QuerySlot::new(message, address);
        match self.query_sender.unbounded_send(slot) {
            Ok(()) => Ok(ReplyHandle { reply_receiver }),
            Err(err) => Err(err.into()),
        }
    }
}

/// A handle for receiving replies to queries
///
/// Returned by the [`QueryHandle`] query methods; await it to get the response.
/// It resolves to an error if the remote node answered with an error message,
/// if its response did not match the query, or if no reply arrived after all
/// retry attempts. Polling panics if the node drops the query without answering.
#[pin_project]
pub struct ReplyHandle {
    /// Receives the query outcome from the node.
    #[pin]
    reply_receiver: Receiver<QueryResult>,
}

impl Future for ReplyHandle {
    type Output = Result<Response>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<Response>> {
        let this = self.project();
        Ready(Ok(
            ready!(this.reply_receiver.poll(cx)).expect("cancelled reply")?
        ))
    }
}

/// Outcome delivered to a [`ReplyHandle`]: the parsed response, or why there is none.
type QueryResult = StdResult<Response, QueryError>;

/// An outgoing query together with the bookkeeping needed to match, retry and
/// answer it.
///
/// Equality and ordering consider only `txid`, so slots can live in a `BTreeSet`
/// and be looked up by transaction id.
struct QuerySlot {
    /// Transaction id; also embedded in `message`.
    txid: TxId,
    /// The full query message that gets bencoded and sent.
    message: Message,
    /// Response kind the query expects; used to parse the reply.
    expected_response: ExpectedResponse,
    /// Destination of the query.
    address: SocketAddrV4,
    /// Delivers the outcome to the waiting `ReplyHandle`.
    reply: Sender<QueryResult>,
    /// Started on creation and restarted on every retry.
    timer: Timer,
    /// Number of attempts made so far, starting at 1.
    attempts: u64,
}

impl QuerySlot {
    pub fn new(message: QueryMessage, address: SocketAddrV4) -> (Self, Receiver<QueryResult>) {
        let txid = random();
        let (sender, receiver) = channel();
        let expected_response = message.expected_response();
        (
            Self {
                txid,
                message: Message::Query { txid, message },
                expected_response,
                address,
                reply: sender,
                timer: Timer::new(Self::QUERY_TIMEOUT),
                attempts: 1,
            },
            receiver,
        )
    }

    pub fn is_expired(&self) -> bool {
        self.timer.is_expired()
    }

    pub fn exhausted_retries(&self) -> bool {
        self.attempts >= Self::MAX_ATTEMPTS
    }

    pub fn mark_retry(&mut self) {
        self.attempts += 1;
        self.timer.restart();
    }

    pub fn answer(self, response: ResponseMessage) {
        let parsed = response
            .parse(self.expected_response)
            .map_err(|_| ErrorMessage::protocol("non-matching response").into());
        let _ = self.reply.send(parsed);
    }

    pub fn fail(self, error: ErrorMessage) {
        let _ = self.reply.send(Err(QueryError::Reply(error)));
    }

    pub fn timeout(self) {
        let _ = self.reply.send(Err(QueryError::Timeout));
    }

    const MAX_ATTEMPTS: u64 = 5;
    const QUERY_TIMEOUT: Duration = Duration::from_secs(60);
}

// A `QuerySlot` is identified solely by its transaction id. Equality, ordering and
// `Borrow<TxId>` must all agree, which is what allows a `BTreeSet<QuerySlot>` to be
// queried with a bare `&TxId`.
impl PartialEq for QuerySlot {
    fn eq(&self, other: &Self) -> bool {
        self.txid == other.txid
    }
}

impl Eq for QuerySlot {}

impl PartialOrd for QuerySlot {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Delegate to `Ord` so the two orderings can never disagree.
        Some(self.cmp(other))
    }
}

impl Ord for QuerySlot {
    fn cmp(&self, other: &Self) -> Ordering {
        self.txid.cmp(&other.txid)
    }
}

impl Borrow<TxId> for QuerySlot {
    fn borrow(&self) -> &TxId {
        &self.txid
    }
}

#[derive(Clone, Debug)]
pub enum QueryError {
    Timeout,
    Reply(ErrorMessage),
}

impl From<QueryError> for Error {
    fn from(e: QueryError) -> Self {
        match e {
            QueryError::Timeout => Self::Timeout,
            QueryError::Reply(m) => Self::from(m),
        }
    }
}

impl From<ErrorMessage> for QueryError {
    fn from(e: ErrorMessage) -> Self {
        Self::Reply(e)
    }
}

/// An incoming query together with the means to send back its answer.
///
/// NOTE(review): the constructor is private and is not called anywhere in the
/// visible part of this module. Confirm whether this type is still in use.
pub struct ResponseSlot {
    /// Transaction id of the incoming query, echoed in the reply.
    txid: TxId,
    /// The query being answered.
    query: QueryMessage,
    /// Where the reply is sent.
    address: SocketAddrV4,
    /// Queue of outgoing `(message, destination)` pairs.
    sender: UnboundedSender<(Message, SocketAddrV4)>,
}

impl ResponseSlot {
    fn new(
        txid: TxId,
        query: QueryMessage,
        address: SocketAddrV4,
        sender: UnboundedSender<(Message, SocketAddrV4)>,
    ) -> Self {
        Self {
            txid,
            query,
            address,
            sender,
        }
    }

    pub fn query(&self) -> &QueryMessage {
        &self.query
    }

    pub fn address(&self) -> SocketAddrV4 {
        self.address
    }

    pub fn answer(self, message: ResponseMessage) {
        let response = Message::Response {
            txid: self.txid,
            message,
        };
        let _ = self.sender.unbounded_send((response, self.address));
    }

    pub fn fail(self, message: ErrorMessage) {
        let error = Message::Error {
            txid: self.txid,
            message,
        };
        let _ = self.sender.unbounded_send((error, self.address));
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        message::{ErrorCode, Port},
        testnet::TestNet,
    };
    use futures::{executor::block_on, future::abortable, join, pin_mut};

    /// Extracts the responder id and the sorted peer list from a `KnownPeers`
    /// response, panicking on any other variant.
    fn known_peers(r: &Response) -> (Id, Vec<SocketAddrV4>) {
        match r {
            Response::KnownPeers { id, peers, .. } => {
                let mut peers: Vec<SocketAddrV4> = peers.iter().copied().collect();
                peers.sort();
                (*id, peers)
            }
            other => panic!("expected KnownPeers, got {:?}", other),
        }
    }

    #[test]
    fn nodes_can_search() {
        let mut net = TestNet::new();
        let sock1 = net.join();
        let sock2 = net.join();
        let sock3 = net.join();
        let sock4 = net.join();
        let addr1 = sock1.address();
        let addr2 = sock2.address();
        let addr3 = sock3.address();
        let addr4 = sock4.address();
        let id1 = random();
        let id2 = random();
        let id3 = random();
        let id4 = random();
        let (node1, mut search1) = Node::new(id1, sock1);
        let (node2, _) = Node::new(id2, sock2);
        let (node3, _) = Node::new(id3, sock3);
        let (node4, _) = Node::new(id4, sock4);
        let mut query1 = node1.get_query_handle();
        let mut query2 = node2.get_query_handle();
        let mut query3 = node3.get_query_handle();
        let mut query4 = node4.get_query_handle();

        let run0 = net.run();
        let run1 = node1.run();
        let run2 = node2.run();
        let run3 = node3.run();
        let run4 = node4.run();

        let (run, abort) = abortable(async move { join!(run0, run1, run2, run3, run4) });

        let replies = async move {
            let r = query1.ping(addr2).unwrap().await.unwrap();
            assert_eq!(r, Response::Pong { id: id2 });

            let r = query2.ping(addr3).unwrap().await.unwrap();
            assert_eq!(r, Response::Pong { id: id3 });

            let r = query3.ping(addr4).unwrap().await.unwrap();
            assert_eq!(r, Response::Pong { id: id4 });

            let r = query4.ping(addr1).unwrap().await.unwrap();
            assert_eq!(r, Response::Pong { id: id1 });

            let r = query1.find_node(id3, addr2).unwrap().await.unwrap();
            assert_eq!(
                r,
                Response::ClosestNodes {
                    id: id2,
                    nodes: vec![NodeInfo::new(id3, addr3), NodeInfo::new(id1, addr1)]
                }
            );

            let r = query1.ping(addr3).unwrap().await.unwrap();
            assert_eq!(r, Response::Pong { id: id3 });

            let info_hash = random();
            let r = query4.get_peers(info_hash, addr2).unwrap().await.unwrap();
            // Compare the id in a guard: `id: id2` inside the pattern would bind a
            // fresh variable instead of checking the value.
            assert!(matches!(
                &r,
                Response::NoKnownPeers { id, nodes, .. } if *id == id2 && !nodes.is_empty()
            ));
            let token = r.token().unwrap();

            let r = query4
                .announce_peer(token.clone(), info_hash, Port::Implicit, addr2)
                .unwrap()
                .await
                .unwrap();
            assert_eq!(r, Response::Announced { id: id2 });

            // Only node 4 has announced so far (with an implied port).
            let r = query3.get_peers(info_hash, addr2).unwrap().await.unwrap();
            assert_eq!(known_peers(&r), (id2, vec![addr4]));

            let e = query3
                .announce_peer(token.clone(), info_hash, Port::Implicit, addr2)
                .unwrap()
                .await
                .unwrap_err();
            assert!(matches!(e, Error::ErrorReply(e) if e.code() == ErrorCode::ProtocolError));
            let token = r.token().unwrap();

            let r = query3
                .announce_peer(token.clone(), info_hash, Port::Implicit, addr2)
                .unwrap()
                .await
                .unwrap();
            assert_eq!(r.id(), id2);

            let r = query4.get_peers(info_hash, addr2).unwrap().await.unwrap();
            let mut expected_peers = vec![addr3, addr4];
            expected_peers.sort();
            assert_eq!(known_peers(&r), (id2, expected_peers));

            {
                let s = search1.search(info_hash, Announce::No);
                pin_mut!(s);
                let mut v = vec![];
                while let Some(p) = s.next().await {
                    v.push(p.unwrap());
                }
                v.sort();
                let mut expected = [addr3, addr4];
                expected.sort();
                assert_eq!(&v, &expected);
            }

            {
                let s = search1.search(info_hash, Announce::SamePort);
                pin_mut!(s);
                let mut v = vec![];
                while let Some(p) = s.next().await {
                    v.push(p.unwrap());
                }
                v.sort();
                let mut expected = [addr3, addr4];
                expected.sort();
                assert_eq!(&v, &expected);
            }
            {
                let s = search1.search(info_hash, Announce::No);
                pin_mut!(s);
                let mut v = vec![];
                while let Some(p) = s.next().await {
                    v.push(p.unwrap());
                }
                v.sort();
                let mut expected = [addr1, addr3, addr4];
                expected.sort();
                assert_eq!(&v, &expected);
            }

            abort.abort();
        };

        let _ = block_on(async move { join!(replies, run) });
    }
}

// TODO: announce this node itself (self-announce)
// TODO: a query that times out is not yet counted like a failed ping
