use crate::{stream, Error, frame};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use log::{error, trace};
use crate::frame::{OpCode, Control, Data, CONTINUATION, TEXT, BINARY, CLOSE, PING, PONG};
use crate::ws::Msg::Close;
use byteorder::{ByteOrder, BigEndian, NetworkEndian};
use crate::deflate::Inflator;
use async_channel::{Sender, Receiver, RecvError};
use simdutf8::basic::from_utf8;

/// Message for communication with the upstream part of the library.
///
/// The reader and writer tasks in this file exchange these values over channels.
#[derive(Debug, Clone)]
pub enum Msg {
    /// Payload of a binary data message.
    Binary(Vec<u8>),
    /// Payload of a text message as raw bytes (UTF-8 is only checked when the reader's
    /// validation flag is enabled).
    Text(Vec<u8>),
    /// Close status code; the reader reports `0` when the peer sent no code it accepts.
    Close(u16),
    /// Payload of a received ping.
    Ping(Vec<u8>),
    /// Pong payload; the reader emits one echoing the payload of each received ping.
    Pong(Vec<u8>),
    /// Not produced in this file; NOTE(review): confirm its meaning with upstream users.
    Clear(),
}


impl Msg {
    #[allow(dead_code)]
    fn is_close(&self) -> bool {
        match self {
            Msg::Close(_) => true,
            _ => false,
        }
    }

    #[allow(dead_code)]
    fn kind(&self) -> &'static str {
        match self {
            Msg::Binary(_) => "binary",
            Msg::Text(_) => "text",
            Msg::Close(_) => "close",
            Msg::Ping(_) => "ping",
            Msg::Pong(_) => "pong",
            Msg::Clear() => "clear",
        }
    }
}

/// Receiving half of a WebSocket connection.
///
/// Reads bytes from `stream_rx`, parses and validates WebSocket frames, reassembles
/// fragmented messages and inflates RSV1-compressed ones when deflate is enabled.
/// Decoded messages are sent on `ws_sender`; the `Pong` answering a ping is sent on
/// `app_sender`; `close_send` is signalled when the read loop shuts the connection down.
pub struct Reader<R> {
    // I/O source and outgoing channels.
    stream_rx: R,
    ws_sender: Sender<Msg>,
    app_sender: Sender<Msg>,
    close_send: Sender<u8>,

    // Behaviour switches and decompression state.
    deflate_supported: bool,
    is_utf8_validate: bool,
    inflator: Inflator,

    // Scratch buffers and frame state reused for every incoming frame.
    /// Chunk buffer the payload is read through.
    payload: [u8; 1024],
    /// Maximum number of payload bytes read per chunk.
    max_payload_len: usize,
    /// The fixed two-byte frame header.
    header_top: [u8; 2],
    /// Extended payload length (2 or 8 bytes are used).
    header_buf: [u8; 8],
    frame: Frame,
}

impl<R> Reader<R>
    where
        R: AsyncReadExt + std::marker::Unpin + std::marker::Send + 'static,
{
    pub fn spawn(stream_rx: R, ws_sender: Sender<Msg>, app_sender: Sender<Msg>, close_send: Sender<u8>, deflate_supported: bool, is_utf8_validate: bool) {
        let mut reader = Reader {
            stream_rx,
            ws_sender,
            close_send,
            app_sender,
            deflate_supported,
            is_utf8_validate,
            inflator: Inflator::new(),
            payload: [0u8; 1024],
            max_payload_len: 1024,
            header_top: [0u8; 2],
            header_buf: [0u8; 8],
            frame: Frame::default(),
        };

        tokio::spawn(async move {
            if let Err(e) = reader.read().await {
                error!("{:?}", e);
            }
        });
    }

    async fn read(&mut self) -> Result<(), Error> {
        let mut fragment = Frame::default();
        let mut is_fragment = false;
        let status = loop {
            let res_header_top = self.stream_rx.read_exact(&mut self.header_top).await;
            match res_header_top {
                Ok(size) => {
                    if size != 2 {
                        error!("size -> {}", size);
                        break;
                    }
                }
                Err(e) => {
                    error!("header_top {:?}", e);
                    self.close_send.send(0).await;
                    break;
                }
            }

            self.frame.apply(&self.header_top);

            if self.frame.header_len > 0 {
                let x = &mut self.header_buf[0..self.frame.header_len];
                let res_header_buf = self.stream_rx.read_exact(x).await;
                match res_header_buf {
                    Ok(size) => {
                        // trace!("header_buf size -> {}", size);
                    }
                    Err(e) => {
                        error!("header_buf {:?}", e);
                        self.close_send.send(0).await;
                        break;
                    }
                }
                self.frame.payload_len = BigEndian::read_uint(x, self.frame.header_len) as usize;
            }

            // trace!("length: {:?}", self.frame.payload_len);

            if self.frame.is_mask {
                let res_masking_key = self.stream_rx.read_exact(&mut self.frame.masking_key).await;
                match res_masking_key {
                    Ok(size) => {
                        // trace!("res_masking_key size -> {}", size);
                    }
                    Err(e) => {
                        error!("masking_key {:?}", e);
                        self.close_send.send(0).await;
                        break;
                    }
                }
            }

            if self.frame.payload_len > 0 {
                let mut read_len = self.frame.payload_len;

                loop {
                    if read_len == 0 {
                        break;
                    }
                    let message = &mut self.payload[0..read_len.min(self.max_payload_len)];
                    let result = self.stream_rx.read_exact(message).await;
                    match result {
                        Ok(size) => {
                            // println!("read_exact len {}", size);
                            read_len -= size;
                            if size == 0 {
                                break;
                            }
                            self.frame.payload.extend_from_slice(&message[0..size]);
                        }
                        Err(e) => {
                            error!("payload {:?}", e);
                            self.close_send.send(0).await;
                            break;
                        }
                    }
                }
            }

            if !self.frame.validate(self.deflate_supported, is_fragment) {
                self.ws_sender.send(Msg::Close(1002)).await;
                self.close_send.send(0).await;
                break;
            }

            if self.frame.is_fragment() {
                if !self.frame.is_final && self.frame.is_data() {
                    // return Fragment::Start;
                    fragment = self.frame.clone();
                    is_fragment = true;
                    continue;
                }
                if !self.frame.is_final && self.frame.is_continuation() {
                    fragment.append(&self.frame);
                    continue;
                }
                if self.frame.is_final && self.frame.is_continuation() {
                    fragment.append(&self.frame);
                    self.frame.set_start_fragment(&fragment);
                    is_fragment = false;
                }
            }

            if self.deflate_supported && self.frame.rsv1 && self.frame.payload_len > 0 {
                let mut compressed = &mut self.frame.payload.clone();
                compressed.extend(&[0, 0, 255, 255]);

                let mut decompressed = Vec::new();

                let res = self.inflator.decompress(&compressed, &mut decompressed);

                if res.is_ok() {
                    self.frame.payload = decompressed;
                    // unsafe {
                    //     trace!("decompressed: {}", String::from_utf8_unchecked(self.frame.payload.to_vec()));
                    // }
                } else {
                    error!("{}", res.unwrap_err());
                }
            }

            if self.is_utf8_validate {
                if !self.frame.validate_payload() {
                    self.close_send.send(0).await;
                    break;
                }
            }

            // unsafe {
            //     trace!("string: {}", String::from_utf8_unchecked(self.frame.payload.to_vec()));
            // }

            // trace!("pos = {}", pos);

            // unsafe {
            //     trace!("string: {}", String::from_utf8_unchecked(out_message.to_vec()));
            // }

            if self.frame.is_text() {
                self.ws_sender.send(Msg::Text(self.frame.payload.clone())).await;
            } else if self.frame.is_binary() {
                self.ws_sender.send(Msg::Binary(self.frame.payload.clone())).await;
            } else if self.frame.is_ping() {
                self.app_sender.send(Msg::Pong(self.frame.payload.clone())).await;
                self.ws_sender.send(Msg::Ping(self.frame.payload.clone())).await;
            } else if self.frame.is_close() {
                let status = self.frame.status();
                // println!("status {}", status);
                self.ws_sender.send(Msg::Close(status)).await;
                self.close_send.send(0).await;
            }

            if self.frame.is_close() {
                break;
            }
        };
        // self.tx.send(Msg::Close(0)).await;
        trace!("reader loop closed");
        Ok(())
    }
}


/// One decoded WebSocket frame header plus its payload bytes.
///
/// `Reader` reuses one instance for every incoming frame; a second instance accumulates
/// the fragments of a fragmented message.
#[derive(Debug, Clone, Default)]
struct Frame {
    // FIN bit: this is the last frame of a message.
    is_final: bool,
    // RSV1 bit; treated as the "compressed" flag when deflate is enabled.
    rsv1: bool,
    rsv2: bool,
    rsv3: bool,
    // The three RSV bits packed as rsv1 << 2 | rsv2 << 1 | rsv3.
    rsv: u8,
    // MASK bit: a 4-byte masking key follows the length field.
    is_mask: bool,
    // 4-bit opcode.
    op: u8,
    // Payload length from the header; after reassembly, the summed length of all fragments.
    payload_len: usize,
    // Number of extended payload-length bytes after the 2-byte header: 0, 2 or 8.
    header_len: usize,
    // Masking key; only meaningful when `is_mask` is set.
    masking_key: [u8; 4],
    // Payload bytes of this frame (or of the whole message after reassembly).
    payload: Vec<u8>,
}

impl Frame {
    fn apply(&mut self, header: &[u8; 2]) {
        // trace!("header: {:?}", header);
        let first = header[0];
        let second = header[1];
        // trace!("First: {:b}", first);
        // trace!("Second: {:b}", second);

        self.is_final = first & 0x80 != 0;

        // trace!("is_final {}", self.is_final);

        self.rsv1 = first & 0x40 != 0;
        self.rsv2 = first & 0x20 != 0;
        self.rsv3 = first & 0x10 != 0;
        self.rsv = (first & 0b0111_0000u8) >> 4;

        // trace!("rsv1={}, rsv2={}, rsv3={}, rsv={}", self.rsv1, self.rsv2, self.rsv3, self.rsv);

        self.is_mask = second & 0x80 != 0;
        // trace!("is_mask: {}", self.is_mask);

        self.op = first & 0x0F;
        // trace!("Opcode: {}", self.op);

        let length_byte = second & 0x7F;

        if length_byte == 126 {
            self.header_len = 2;
        } else if length_byte == 127 {
            self.header_len = 8;
        } else {
            self.header_len = 0;
            self.payload_len = u64::from(length_byte) as usize;
        }
        self.payload.clear();
    }

    fn status(&self) -> u16 {
        if self.payload_len != 2 {
            return 0;
        }
        let bytes: [u8; 2] = [self.payload[0], self.payload[1]];
        let status = u16::from_be_bytes(bytes);
        match status {
            1000 | 1001 | 1002 | 1003 | 1007 | 1008 | 1009 | 1010 | 1011 => status, /* valid status code, reply with */
            // that code
            _ => 0, // for all other reply with close frame without payload
        }
    }

    fn append(&mut self, other: &Frame) -> &Frame {
        self.payload_len = self.payload_len + other.payload_len;
        self.payload.extend_from_slice(&other.payload);
        self
    }

    fn set_start_fragment(&mut self, first: &Frame) {
        self.is_final = true;
        self.rsv1 = first.rsv1;
        self.rsv2 = first.rsv2;
        self.rsv3 = first.rsv3;
        self.rsv = first.rsv;
        self.is_mask = first.is_mask;
        self.op = first.op;
        self.header_len = first.header_len;
        self.payload_len = first.payload_len;
        self.masking_key = first.masking_key.clone();
        self.payload = first.payload.clone();
    }

    fn is_valid(&self) -> bool {
        self.is_data() || self.is_control() || self.is_continuation()
    }

    fn is_data(&self) -> bool {
        self.op == TEXT || self.op == BINARY
    }
    fn is_control(&self) -> bool {
        self.op == CLOSE || self.op == PING || self.op == PONG
    }
    fn is_continuation(&self) -> bool {
        self.op == CONTINUATION
    }

    fn is_text(&self) -> bool {
        self.op == TEXT
    }

    fn is_binary(&self) -> bool {
        self.op == BINARY
    }

    fn is_ping(&self) -> bool {
        self.op == PING
    }

    fn is_close(&self) -> bool {
        self.op == CLOSE
    }

    fn is_rsv_ok(&self, deflate_supported: bool) -> bool {
        if deflate_supported {
            return self.rsv == 0 || self.rsv == 4;
        }
        // rsv must be 0, when no extension defining RSV meaning has been negotiated
        self.rsv == 0
    }

    fn validate_payload(&self) -> bool {
        if !self.is_text() {
            return true;
        }
        let res = from_utf8(self.payload.as_slice());
        if res.is_err() {
            error!("from_utf8 {}", res.unwrap_err());
            return false;
        }
        return true;
    }

    fn validate(&self, deflate_supported: bool, in_continuation: bool) -> bool {
        if !self.is_valid() {
            error!("reserved opcode {}", self.op);
            return false;
        }
        if self.is_control() {
            if self.payload_len > 125 {
                error!("too long control frame {} > 125", self.payload_len);
                return false;
            }
            if !self.is_final {
                error!("fragmented control frame");
                return false;
            }
        } else {
            // continuation (waiting for more fragments) frames must be in order
            // start/middle.../end
            if !in_continuation && self.is_continuation() {
                error!("not in continuation");
                return false;
            }
            if in_continuation && !self.is_continuation() {
                error!("fin frame during continuation");
                return false;
            }
        }
        if !self.is_rsv_ok(deflate_supported) {
            // only bit 1 of rsv is currently used
            error!("wrong rsv");
            return false;
        }
        return true;
    }

    fn is_fragment(&self) -> bool {
        !(self.is_final && !self.is_continuation())
    }
}

/// Sending half of a WebSocket connection.
///
/// Serializes each `Msg` taken from `app_rx` into frame bytes with `frame_writer` and
/// writes them to the outbound stream `stream_tx`.
pub struct Writer<W> {
    app_rx: Receiver<Msg>,
    frame_writer: frame::FrameWriter,
    stream_tx: W,
    /// Masking flag; `spawn` builds `frame_writer` from the same value.
    mask_frames: bool,
}

impl<W> Writer<W>
    where
        W: AsyncWriteExt + std::marker::Unpin + std::marker::Send + 'static,
{
    pub fn spawn(stream_tx: W, mask_frames: bool, app_rx: Receiver<Msg>) {
        tokio::spawn(async move {
            let mut writer = Writer {
                stream_tx,
                mask_frames,
                frame_writer: frame::FrameWriter::new(mask_frames),
                app_rx,
            };

            if let Err(e) = writer.run().await {
                error!("{:?}", e);
            }
            trace!("writer loop closed");
        });
    }

    async fn run(&mut self) -> Result<(), Error> {
        loop {
            let app = self.app_rx.recv().await;
            match app {
                Ok(msg) => {
                    let is_close = msg.is_close();
                    self.write(msg).await;
                    if is_close {
                        break;
                    }
                }
                Err(err) => {
                    // when the application writer goes out of scope
                    // self.write(Msg::Close(0)).await;
                    break;
                }
            }
        }
        Ok(())
    }

    async fn write(&mut self, msg: Msg) -> tokio::io::Result<()> {
        // println!("write: {:?}", msg);
        let raw: Vec<u8> = self.frame_writer.msg_vec(msg);
        // println!("write raw: {:?}", raw);
        // print_frame(raw.clone());

        let mut buf = raw.as_slice();

        loop {
            let len = self.stream_tx.write(buf).await.unwrap();
            // println!("write len {}", len);
            if len == buf.len() {
                break;
            }
            buf = &buf[len..];
        }
        Ok(())
    }
}


fn print_frame(buf: Vec<u8>) {
    println!("print_frame start");
    let mut pos = 0usize;
    let first = buf[0];
    let second = buf[1];
    trace!("First: {:b}", first);
    trace!("Second: {:b}", second);

    pos += 2;
    let is_final = first & 0x80 != 0;

    trace!("is_final {}", is_final);

    let rsv1 = first & 0x40 != 0;
    let rsv2 = first & 0x20 != 0;
    let rsv3 = first & 0x10 != 0;

    let op = first & 0x0F;
    let opcode = OpCode::from(op);
    trace!("Opcode: {:?} {}", opcode, op);

    let masked = second & 0x80 != 0;
    trace!("Masked: {:?}", masked);

    // trace!("self.buf: {:?}", buf);
    let length = {
        let length_byte = second & 0x7F;
        let length_length = frame::LengthFormat::for_byte(length_byte).extra_bytes();
        if length_length > 0 {
            let x = &buf[pos..pos + length_length];
            pos += length_length;
            trace!("self.buf: {:?}", x);
            BigEndian::read_uint(&x[..length_length], length_length) as usize
        } else {
            u64::from(length_byte) as usize
        }
    };

    // trace!("self.buf: {:?}", buf);
    trace!("length: {:?}", length);

    let mask = if masked {
        let mask_bytes = &buf[pos..pos + 4];
        pos += 4;
        Some(mask_bytes)
    } else {
        None
    };
    // Disallow bad opcode
    match opcode {
        OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
            error!("Encountered invalid opcode: {}", first & 0x0F);
        }
        _ => (),
    }

    let mut out_message = Vec::new();

    let message = &buf[pos..pos + length];
    out_message = message.to_vec();

    // unsafe {
    //     trace!("string: {}", String::from_utf8_unchecked(out_message.to_vec()));
    // }
    println!("print_frame end");
}


// if length > 0 {
//     if rsv1 && length > 0 {
//         let message = &mut self.buf[pos..pos + length + 4];
//
//         message[length] = 0;
//         message[length + 1] = 0;
//         message[length + 2] = 255;
//         message[length + 3] = 255;
//
//         let mut decompressed = Vec::new();
//         // compressed.put_slice(message);
//
//         // compressed.extend(&[0, 0, 255, 255]);
//
//         let res = self.inflator.decompress(&message, &mut decompressed);
//
//         if res.is_ok() {
//             out_message = decompressed;
//             // unsafe {
//             //     let st = String::from_utf8_unchecked(out_message.clone());
//             //     trace!("string: {}", st);
//             //     if st.chars().nth(0).unwrap() != '{' {
//             //         trace!("char: ...");
//             //     }
//             // }
//         } else {
//             trace!("{}", res.unwrap_err());
//         }
//     } else if !rsv1 && length > 0 {
//         let message = &mut self.buf[pos..pos + length];
//         out_message = message.to_vec();
//     } else {
//         trace!("{} {}", rsv1, length)
//     }
// }