use self::mdns::MdnsWrapper;
use futures::FutureExt;
use futures_timer::Delay;
use ip_network::IpNetwork;
use libp2p::{
    core::{
        connection::{ConnectionId, ListenerId},
        ConnectedPoint,
    },
    kad::{handler::KademliaHandlerProto, store::MemoryStore, Kademlia, KademliaEvent, QueryId},
    mdns::MdnsEvent,
    multiaddr::Protocol,
    swarm::{
        DialError, IntoProtocolsHandler, NetworkBehaviour, NetworkBehaviourAction, PollParameters,
        ProtocolsHandler,
    },
    Multiaddr, PeerId,
};
use std::{
    collections::VecDeque,
    io,
    task::{Context, Poll},
    time::Duration,
};
use tracing::trace;
mod discovery_config;
mod mdns;
pub use discovery_config::DiscoveryConfig;

/// Upper bound for the delay between two consecutive Kademlia random walks (one minute).
const SIXTY_SECONDS: Duration = Duration::from_millis(60_000);

/// Events that `DiscoveryBehaviour` hands up to the owning swarm.
#[derive(Debug)]
pub enum DiscoveryEvent {
    /// Kademlia reported a peer for which it has no known address
    /// (forwarded from `KademliaEvent::UnroutablePeer`).
    UnroutablePeer(PeerId),

    /// A peer became connected; carries the addresses known for it at that moment.
    Connected(PeerId, Vec<Multiaddr>),

    /// A previously connected peer is no longer connected.
    Disconnected(PeerId),
}

/// `NetworkBehaviour` that discovers peers through Kademlia and (optionally) mDNS.
pub struct DiscoveryBehaviour {
    /// Bootstrap peers paired with the address each one can be reached at.
    bootstrap_nodes: Vec<(PeerId, Multiaddr)>,

    /// Queue of events waiting to be reported to the swarm by `poll`.
    events: VecDeque<DiscoveryEvent>,

    /// Local-network discovery; presumably a no-op when mDNS is disabled
    /// (NOTE(review): confirm against `MdnsWrapper`).
    mdns: MdnsWrapper,

    /// Kademlia DHT backed by an in-memory record store.
    kademlia: Kademlia<MemoryStore>,

    /// Timer that triggers the next random walk; `None` disables random walks.
    next_kad_random_walk: Option<Delay>,

    /// Delay to use once the currently scheduled random walk has fired.
    duration_to_next_kad: Duration,

    /// Number of currently established connections.
    connected_peers_count: u64,

    /// Connection count at or above which no new discovery is initiated.
    max_peers_connected: u64,

    /// When `false`, `addresses_of_peer` drops non-global IPv4/IPv6 addresses learned from
    /// Kademlia or mDNS; addresses of `bootstrap_nodes` are always returned.
    allow_private_addresses: bool,
}

impl DiscoveryBehaviour {
    /// Registers `address` as a listen address of `peer_id` in the Kademlia routing table.
    ///
    /// Whatever `Kademlia::add_address` returns is ignored.
    pub fn add_address(&mut self, peer_id: &PeerId, address: Multiaddr) {
        let dht = &mut self.kademlia;
        dht.add_address(peer_id, address);
    }
}

impl NetworkBehaviour for DiscoveryBehaviour {
    type ProtocolsHandler = KademliaHandlerProto<QueryId>;
    type OutEvent = DiscoveryEvent;

    // Initializes new handler on a new opened connection
    fn new_handler(&mut self) -> Self::ProtocolsHandler {
        // in our case we just return KademliaHandlerProto
        self.kademlia.new_handler()
    }

    // receives events from KademliaHandler and pass it down to kademlia
    fn inject_event(
        &mut self,
        peer_id: PeerId,
        connection: ConnectionId,
        event: <<Self::ProtocolsHandler as IntoProtocolsHandler>::Handler as ProtocolsHandler>::OutEvent,
    ) {
        self.kademlia.inject_event(peer_id, connection, event);
    }

    // gets polled by the swarm
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
        params: &mut impl PollParameters,
    ) -> Poll<NetworkBehaviourAction<Self::OutEvent, Self::ProtocolsHandler>> {
        if let Some(next_event) = self.events.pop_front() {
            return Poll::Ready(NetworkBehaviourAction::GenerateEvent(next_event));
        }

        // if random walk is enabled poll the stream that will fire when random walk is scheduled
        if let Some(next_kad_random_query) = self.next_kad_random_walk.as_mut() {
            while next_kad_random_query.poll_unpin(cx).is_ready() {
                if self.connected_peers_count < self.max_peers_connected {
                    let random_peer_id = PeerId::random();
                    self.kademlia.get_closest_peers(random_peer_id);
                }

                *next_kad_random_query = Delay::new(self.duration_to_next_kad);
                // duration to next random walk should either be exponentially bigger than the previous
                // or at max 60 seconds
                self.duration_to_next_kad =
                    std::cmp::min(self.duration_to_next_kad * 2, SIXTY_SECONDS);
            }
        }

        // poll Kademlia behaviour
        while let Poll::Ready(kad_action) = self.kademlia.poll(cx, params) {
            match kad_action {
                NetworkBehaviourAction::GenerateEvent(KademliaEvent::UnroutablePeer { peer }) => {
                    return Poll::Ready(NetworkBehaviourAction::GenerateEvent(
                        DiscoveryEvent::UnroutablePeer(peer),
                    ))
                }

                NetworkBehaviourAction::Dial { handler, opts } => {
                    return Poll::Ready(NetworkBehaviourAction::Dial { handler, opts });
                }
                NetworkBehaviourAction::CloseConnection {
                    peer_id,
                    connection,
                } => {
                    return Poll::Ready(NetworkBehaviourAction::CloseConnection {
                        peer_id,
                        connection,
                    });
                }
                NetworkBehaviourAction::NotifyHandler {
                    peer_id,
                    handler,
                    event,
                } => {
                    return Poll::Ready(NetworkBehaviourAction::NotifyHandler {
                        peer_id,
                        handler,
                        event,
                    })
                }
                NetworkBehaviourAction::ReportObservedAddr { address, score } => {
                    return Poll::Ready(NetworkBehaviourAction::ReportObservedAddr {
                        address,
                        score,
                    })
                }
                _ => {}
            }
        }

        while let Poll::Ready(mdns_event) = self.mdns.poll(cx, params) {
            match mdns_event {
                NetworkBehaviourAction::GenerateEvent(MdnsEvent::Discovered(list)) => {
                    // inform kademlia of newly discovered local peers
                    // only if there aren't enough peers already connected
                    if self.connected_peers_count < self.max_peers_connected {
                        for (peer_id, multiaddr) in list {
                            self.kademlia.add_address(&peer_id, multiaddr);
                        }
                    }
                }
                NetworkBehaviourAction::ReportObservedAddr { address, score } => {
                    return Poll::Ready(NetworkBehaviourAction::ReportObservedAddr {
                        address,
                        score,
                    })
                }
                NetworkBehaviourAction::CloseConnection {
                    peer_id,
                    connection,
                } => {
                    return Poll::Ready(NetworkBehaviourAction::CloseConnection {
                        peer_id,
                        connection,
                    })
                }
                _ => {}
            }
        }

        if let Some(next_event) = self.events.pop_front() {
            return Poll::Ready(NetworkBehaviourAction::GenerateEvent(next_event));
        }

        Poll::Pending
    }

    /// return list of known addresses for a given peer
    fn addresses_of_peer(&mut self, peer_id: &PeerId) -> Vec<Multiaddr> {
        let mut list = self
            .bootstrap_nodes
            .iter()
            .filter_map(|(current_peer_id, multiaddr)| {
                if current_peer_id == peer_id {
                    Some(multiaddr.clone())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();

        {
            let mut list_to_filter = Vec::new();

            list_to_filter.extend(self.kademlia.addresses_of_peer(peer_id));
            list_to_filter.extend(self.mdns.addresses_of_peer(peer_id));

            // filter private addresses
            // nodes could potentially report addresses in the private network
            // which are not actually part of the network
            if !self.allow_private_addresses {
                list_to_filter.retain(|addr| match addr.iter().next() {
                    Some(Protocol::Ip4(addr)) if !IpNetwork::from(addr).is_global() => false,
                    Some(Protocol::Ip6(addr)) if !IpNetwork::from(addr).is_global() => false,
                    _ => true,
                });
            }

            list.extend(list_to_filter);
        }

        trace!("Addresses of {:?}: {:?}", peer_id, list);

        list
    }

    fn inject_connection_established(
        &mut self,
        peer_id: &PeerId,
        connection_id: &ConnectionId,
        endpoint: &ConnectedPoint,
        failed_addresses: Option<&Vec<Multiaddr>>,
    ) {
        self.connected_peers_count += 1;

        self.kademlia.inject_connection_established(
            peer_id,
            connection_id,
            endpoint,
            failed_addresses,
        )
    }

    fn inject_connected(&mut self, peer_id: &PeerId) {
        let addresses = self.addresses_of_peer(peer_id);

        self.events
            .push_back(DiscoveryEvent::Connected(*peer_id, addresses));
        self.kademlia.inject_connected(peer_id);
        trace!("Connected to a peer {:?}", peer_id);
    }

    fn inject_connection_closed(
        &mut self,
        _peer_id: &PeerId,
        _connection_id: &ConnectionId,
        _connection_point: &ConnectedPoint,
        _handler: <Self::ProtocolsHandler as IntoProtocolsHandler>::Handler,
    ) {
        self.connected_peers_count -= 1;
        // no need to pass it to kademlia.inject_connection_closed() since it does nothing
    }

    fn inject_disconnected(&mut self, peer_id: &PeerId) {
        self.events
            .push_back(DiscoveryEvent::Disconnected(*peer_id));

        self.kademlia.inject_disconnected(peer_id)
    }

    fn inject_new_external_addr(&mut self, addr: &Multiaddr) {
        self.kademlia.inject_new_external_addr(addr)
    }

    fn inject_expired_listen_addr(&mut self, id: ListenerId, addr: &Multiaddr) {
        self.kademlia.inject_expired_listen_addr(id, addr);
    }

    fn inject_dial_failure(
        &mut self,
        peer_id: Option<PeerId>,
        handler: Self::ProtocolsHandler,
        err: &DialError,
    ) {
        self.kademlia.inject_dial_failure(peer_id, handler, err)
    }

    fn inject_new_listen_addr(&mut self, id: ListenerId, addr: &Multiaddr) {
        self.kademlia.inject_new_listen_addr(id, addr)
    }

    fn inject_listener_error(&mut self, id: ListenerId, err: &(dyn std::error::Error + 'static)) {
        self.kademlia.inject_listener_error(id, err)
    }

    fn inject_listener_closed(&mut self, id: ListenerId, reason: Result<(), &io::Error>) {
        self.kademlia.inject_listener_closed(id, reason)
    }
}

#[cfg(test)]
mod tests {
    use super::{DiscoveryBehaviour, DiscoveryConfig};
    use crate::discovery::DiscoveryEvent;
    use futures::{future::poll_fn, StreamExt};
    use libp2p::{
        core, identity::Keypair, multiaddr::Protocol, noise, swarm::SwarmEvent, yamux, Multiaddr,
        PeerId, Swarm, Transport,
    };
    use std::{
        collections::{HashSet, VecDeque},
        task::Poll,
        time::Duration,
    };

    /// Builds a memory-transport swarm running `DiscoveryBehaviour` and starts it listening
    /// on a random in-memory address.
    ///
    /// Returns the swarm, the address it listens on, and its local `PeerId`.
    fn build_fuel_discovery(
        bootstrap_nodes: Vec<(PeerId, Multiaddr)>,
    ) -> (Swarm<DiscoveryBehaviour>, Multiaddr, PeerId) {
        let keypair = Keypair::generate_secp256k1();
        let local_peer_id = keypair.public().to_peer_id();

        let noise_keys = noise::Keypair::<noise::X25519Spec>::new()
            .into_authentic(&keypair)
            .expect("noise keys can be signed with a freshly generated keypair");

        let transport = core::transport::MemoryTransport
            .upgrade(core::upgrade::Version::V1)
            .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated())
            .multiplex(yamux::YamuxConfig::default())
            .boxed();

        let behaviour = {
            let mut config = DiscoveryConfig::new(local_peer_id, "test_network".into());
            config
                .discovery_limit(50)
                .with_bootstrap_nodes(bootstrap_nodes)
                .set_connection_idle_timeout(Duration::from_secs(120))
                .enable_random_walk(true);

            config.finish()
        };

        let listen_addr: Multiaddr = Protocol::Memory(rand::random::<u64>()).into();
        let mut swarm = Swarm::new(transport, behaviour, local_peer_id);

        swarm
            .listen_on(listen_addr.clone())
            .expect("swarm should start listening");

        (swarm, listen_addr, local_peer_id)
    }

    // Builds 25 discovery swarms. Initially every swarm only knows first_swarm (as its
    // bootstrap node). From there each swarm uses Kademlia to discover the others. The
    // test completes once every swarm has connected to every other swarm.
    #[tokio::test]
    async fn discovery_works() {
        // `try_init` instead of `init`: another test in this binary may already have
        // installed the logger, and `init` would panic in that case.
        let _ = env_logger::try_init();
        // Number of peers in the network
        let num_of_swarms = 25;
        let (first_swarm, first_peer_addr, first_peer_id) = build_fuel_discovery(vec![]);

        let mut discovery_swarms = (0..num_of_swarms - 1)
            .map(|_| build_fuel_discovery(vec![(first_peer_id, first_peer_addr.clone())]))
            .collect::<VecDeque<_>>();

        discovery_swarms.push_front((first_swarm, first_peer_addr, first_peer_id));

        // For each swarm: the set of peers it still has to connect to. Index 0 (first_swarm)
        // is skipped because every other swarm bootstraps from it, and each swarm skips itself.
        let mut left_to_discover = (0..discovery_swarms.len())
            .map(|current_index| {
                (1..discovery_swarms.len())
                    .filter(|&swarm_index| swarm_index != current_index)
                    .map(|swarm_index| discovery_swarms[swarm_index].2)
                    .collect::<HashSet<_>>()
            })
            .collect::<Vec<_>>();

        let test_future = poll_fn(move |cx| {
            'polling: loop {
                for swarm_index in 0..discovery_swarms.len() {
                    if let Poll::Ready(Some(event)) =
                        discovery_swarms[swarm_index].0.poll_next_unpin(cx)
                    {
                        if let SwarmEvent::Behaviour(discovery_event) = event {
                            match discovery_event {
                                // a peer has connected - it no longer needs to be discovered
                                DiscoveryEvent::Connected(connected_peer, _) => {
                                    left_to_discover[swarm_index].remove(&connected_peer);
                                }
                                DiscoveryEvent::UnroutablePeer(unroutable_peer_id) => {
                                    // Kademlia discovered a peer but has no address for it;
                                    // simulate Identify by supplying the known test address.
                                    let unroutable_peer_addr = discovery_swarms
                                        .iter()
                                        .find_map(|(_, next_addr, next_peer_id)| {
                                            (next_peer_id == &unroutable_peer_id)
                                                .then(|| next_addr.clone())
                                        })
                                        .expect("unroutable peer must be one of the test swarms");

                                    // Kademlia must know a peer's address before it can add
                                    // the peer to its routing table.
                                    discovery_swarms[swarm_index]
                                        .0
                                        .behaviour_mut()
                                        .kademlia
                                        .add_address(&unroutable_peer_id, unroutable_peer_addr);
                                }
                                DiscoveryEvent::Disconnected(peer_id) => {
                                    panic!("PeerId {:?} disconnected", peer_id);
                                }
                            }
                        }
                        continue 'polling;
                    }
                }
                break;
            }

            // done once no swarm has peers left to discover
            if left_to_discover.iter().all(|l| l.is_empty()) {
                Poll::Ready(())
            } else {
                // keep polling the discovery swarms
                Poll::Pending
            }
        });

        test_future.await;
    }
}
