//! Milter commands.

use crate::{frame::Frame, macros::Stage, session::State};
use bitflags::bitflags;
use bytes::{Buf, Bytes};
use std::{
    convert::{TryFrom, TryInto},
    error::Error,
    ffi::{CStr, CString},
    fmt::{self, Display, Formatter},
    net::{SocketAddr, SocketAddrV4, SocketAddrV6},
};

/// The kind of a milter command, identified by a single command byte.
///
/// The byte noted on each variant is the one used by the `From<CommandKind> for u8`
/// and `TryFrom<u8> for CommandKind` conversions in this module.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum CommandKind {
    /// Command byte `A`.
    Abort,
    /// Command byte `B`.
    BodyChunk,
    /// Command byte `C`.
    ConnInfo,
    /// Command byte `D`.
    DefMacro,
    /// Command byte `E`.
    BodyEnd,
    /// Command byte `H`.
    Helo,
    /// Command byte `K`.
    QuitNc,
    /// Command byte `L`.
    Header,
    /// Command byte `M`.
    Mail,
    /// Command byte `N`.
    Eoh,
    /// Command byte `O`.
    OptNeg,
    /// Command byte `Q`.
    Quit,
    /// Command byte `R`.
    Rcpt,
    /// Command byte `T`.
    Data,
    /// Command byte `U`.
    Unknown,
}

impl CommandKind {
    /// Returns the session [`State`] associated with this command kind.
    ///
    /// Every kind maps to a state except [`CommandKind::DefMacro`], which yields `None`.
    pub fn as_state(&self) -> Option<State> {
        let state = match *self {
            // Macro definitions are the only kind without a state of their own.
            Self::DefMacro => return None,
            Self::Abort => State::Abort,
            Self::BodyChunk => State::Body,
            Self::ConnInfo => State::Conn,
            Self::BodyEnd => State::Eom,
            Self::Helo => State::Helo,
            Self::QuitNc => State::QuitNc,
            Self::Header => State::Header,
            Self::Mail => State::Mail,
            Self::Eoh => State::Eoh,
            Self::OptNeg => State::Opts,
            Self::Quit => State::Quit,
            Self::Rcpt => State::Rcpt,
            Self::Data => State::Data,
            Self::Unknown => State::Unknown,
        };

        Some(state)
    }
}

impl From<CommandKind> for u8 {
    fn from(kind: CommandKind) -> Self {
        match kind {
            CommandKind::Abort => b'A',
            CommandKind::BodyChunk => b'B',
            CommandKind::ConnInfo => b'C',
            CommandKind::DefMacro => b'D',
            CommandKind::BodyEnd => b'E',
            CommandKind::Helo => b'H',
            CommandKind::QuitNc => b'K',
            CommandKind::Header => b'L',
            CommandKind::Mail => b'M',
            CommandKind::Eoh => b'N',
            CommandKind::OptNeg => b'O',
            CommandKind::Quit => b'Q',
            CommandKind::Rcpt => b'R',
            CommandKind::Data => b'T',
            CommandKind::Unknown => b'U',
        }
    }
}

/// Decodes a single command byte into a [`CommandKind`].
///
/// Any byte that is not one of the known command codes is rejected with
/// [`CommandError::UnknownCommandKind`].
impl TryFrom<u8> for CommandKind {
    type Error = CommandError;  // TODO error type

    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let kind = match value {
            b'A' => Self::Abort,
            b'B' => Self::BodyChunk,
            b'C' => Self::ConnInfo,
            b'D' => Self::DefMacro,
            b'E' => Self::BodyEnd,
            b'H' => Self::Helo,
            b'K' => Self::QuitNc,
            b'L' => Self::Header,
            b'M' => Self::Mail,
            b'N' => Self::Eoh,
            b'O' => Self::OptNeg,
            b'Q' => Self::Quit,
            b'R' => Self::Rcpt,
            b'T' => Self::Data,
            b'U' => Self::Unknown,
            _ => return Err(CommandError::UnknownCommandKind),
        };

        Ok(kind)
    }
}

/// A frame whose kind byte has been recognised as a [`CommandKind`], with its
/// payload still unparsed.
#[derive(Debug)]
pub struct CommandFrame {
    /// The command kind decoded from the frame's kind byte.
    pub kind: CommandKind,
    /// The raw payload bytes, taken over unchanged from the frame.
    pub buffer: Bytes,
}

impl TryFrom<Frame> for CommandFrame {
    type Error = CommandError;

    fn try_from(frame: Frame) -> Result<Self, Self::Error> {
        let kind = frame.kind.try_into()?;
        let buffer = frame.buffer;

        Ok(Self { kind, buffer })
    }
}

// TODO actually CommandParseError
/// Errors produced while turning frames and payload buffers into commands.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CommandError {
    // related to buffer only:
    /// The buffer does not end with a NUL byte.
    NotNulTerminated,
    /// No NUL byte was found, so no C string could be read.
    NoCStringFound,
    /// The buffer was exhausted where a `u8` was expected.
    NoU8Found,
    /// Fewer than two bytes remained where a `u16` was expected.
    NoU16Found,
    /// The macro stage byte could not be converted into a stage.
    UnknownMacroStage,
    /// The socket address string was not valid UTF-8 or did not parse as an address.
    InvalidSocketAddr,
    /// The socket family byte is not one of the recognised families.
    UnknownFamily,
    /// The option negotiation payload is too short.
    MissingOptneg,
    /// The negotiated protocol version is too old.
    UnsupportedProtocolVersion,
    /// A C string that must not be empty was empty.
    EmptyCString,

    // related to kind only:
    /// The command byte does not correspond to any known command kind.
    UnknownCommandKind,

    // related to stage only:
    /// Not produced in this module; presumably used by the stage conversion elsewhere.
    UnknownStage,
}

/// The display form is the variant name, exactly as rendered by `Debug`.
impl Display for CommandError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}

impl Error for CommandError {}

/// A command.
///
/// This is the fully parsed form of a [`CommandFrame`]; see [`Command::from_frame`].
/// Variants without a payload type carry no data.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Command {
    /// The `A` command. Carries no payload.
    Abort,
    /// The `B` command, carrying a body chunk.
    BodyChunk(BodyPayload),
    /// The `C` command, carrying connection information.
    ConnInfo(ConnInfoPayload),
    /// The `D` command, carrying macro definitions.
    DefMacro(MacroPayload),
    /// The `E` command. Uses the same payload type as `BodyChunk`.
    BodyEnd(BodyPayload),
    /// The `H` command, carrying a hostname.
    Helo(HeloPayload),
    /// The `K` command. Carries no payload.
    QuitNc,
    /// The `L` command, carrying a header name and value.
    Header(HeaderPayload),
    /// The `M` command, carrying envelope address arguments.
    Mail(EnvAddrPayload),
    /// The `N` command. Carries no payload.
    Eoh,
    /// The `O` command, carrying option negotiation data.
    OptNeg(OptNegPayload),
    /// The `Q` command. Carries no payload.
    Quit,
    /// The `R` command. Uses the same payload type as `Mail`.
    Rcpt(EnvAddrPayload),
    /// The `T` command. Carries no payload.
    Data,
    /// The `U` command, carrying a single string argument.
    Unknown(UnknownPayload),
}

impl Command {
    pub fn from_frame(frame: CommandFrame) -> Result<Self, CommandError> {
        Ok(match frame.kind {
            CommandKind::Abort => Self::Abort,
            CommandKind::BodyChunk => Self::BodyChunk(BodyPayload::parse_buffer(frame.buffer)?),
            CommandKind::ConnInfo => Self::ConnInfo(ConnInfoPayload::parse_buffer(frame.buffer)?),
            CommandKind::DefMacro => Self::DefMacro(MacroPayload::parse_buffer(frame.buffer)?),
            CommandKind::BodyEnd => Self::BodyEnd(BodyPayload::parse_buffer(frame.buffer)?),
            CommandKind::Helo => Self::Helo(HeloPayload::parse_buffer(frame.buffer)?),
            CommandKind::QuitNc => Self::QuitNc,
            CommandKind::Header => Self::Header(HeaderPayload::parse_buffer(frame.buffer)?),
            CommandKind::Mail => Self::Mail(EnvAddrPayload::parse_buffer(frame.buffer)?),
            CommandKind::Eoh => Self::Eoh,
            CommandKind::OptNeg => Self::OptNeg(OptNegPayload::parse_buffer(frame.buffer)?),
            CommandKind::Quit => Self::Quit,
            CommandKind::Rcpt => Self::Rcpt(EnvAddrPayload::parse_buffer(frame.buffer)?),
            CommandKind::Data => Self::Data,
            CommandKind::Unknown => Self::Unknown(UnknownPayload::parse_buffer(frame.buffer)?),
        })
    }
}

/// A body chunk payload.
///
/// Used for both the `BodyChunk` and the `BodyEnd` command.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct BodyPayload {
    /// The raw bytes of the chunk, copied verbatim from the command buffer.
    pub chunk: Vec<u8>,
}

impl BodyPayload {
    /// Copies the whole buffer into a new payload.
    ///
    /// No validation is performed, so this never returns an error; an empty
    /// buffer yields an empty chunk.
    pub fn parse_buffer(buf: Bytes) -> Result<Self, CommandError> {
        Ok(Self {
            chunk: buf.to_vec(),
        })
    }
}

/// A `ConnInfo` command payload.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ConnInfoPayload {
    /// The hostname: the first NUL-terminated string of the payload.
    pub hostname: CString,
    /// The socket information, or `None` if the family byte was `U`.
    pub socket_info: Option<SocketInfo>,
}

/// Socket information.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum SocketInfo {
    /// An IPv4 or IPv6 address with port (family byte `4` or `6`).
    Inet(SocketAddr),
    /// A path given as a C string (family byte `L`).
    Unix(CString),
}

/// The socket family announced by the family byte of a `ConnInfo` payload.
enum Family {
    /// Family byte `U`.
    Unknown,
    /// Family byte `4`.
    Ipv4,
    /// Family byte `6`.
    Ipv6,
    /// Family byte `L`.
    Unix,
}

/// Decodes the family byte; unrecognised bytes give [`CommandError::UnknownFamily`].
impl TryFrom<u8> for Family {
    type Error = CommandError;

    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let family = match value {
            b'U' => Self::Unknown,
            b'4' => Self::Ipv4,
            b'6' => Self::Ipv6,
            b'L' => Self::Unix,
            _ => return Err(CommandError::UnknownFamily),
        };

        Ok(family)
    }
}

impl ConnInfoPayload {
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, CommandError> {
        let hostname = get_c_string(&mut buf)?;

        let family = get_u8(&mut buf)?.try_into()?;

        let socket_info = match family {
            Family::Unknown => None,
            Family::Ipv4 => {
                let port = get_u16(&mut buf)?;

                ensure_nul_terminated(&buf)?;

                let addr = get_c_string(&mut buf)?;
                let addr = addr
                    .into_string()
                    .map_err(|_| CommandError::InvalidSocketAddr)?
                    .parse()
                    .map_err(|_| CommandError::InvalidSocketAddr)?;

                Some(SocketInfo::Inet(SocketAddrV4::new(addr, port).into()))
            }
            Family::Ipv6 => {
                let port = get_u16(&mut buf)?;

                ensure_nul_terminated(&buf)?;

                let addr = get_c_string(&mut buf)?;
                let addr = addr
                    .into_string()
                    .map_err(|_| CommandError::InvalidSocketAddr)?
                    .parse()
                    .map_err(|_| CommandError::InvalidSocketAddr)?;

                Some(SocketInfo::Inet(SocketAddrV6::new(addr, port, 0, 0).into()))
            }
            Family::Unix => {
                let _unused = get_u16(&mut buf)?;

                ensure_nul_terminated(&buf)?;

                // TODO is CString correct for path?
                let path = get_c_string(&mut buf)?;

                Some(SocketInfo::Unix(path))
            }
        };

        Ok(Self {
            hostname,
            socket_info,
        })
    }
}

/// A `DefMacro` command payload.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct MacroPayload {
    /// The stage, decoded from the first byte of the payload.
    pub stage: Stage,
    /// The NUL-terminated strings following the stage byte, in buffer order.
    pub macros: Vec<CString>,  // non-empty
}

impl MacroPayload {
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, CommandError> {
        let stage = get_u8(&mut buf)?
            .try_into()
            .map_err(|_| CommandError::UnknownMacroStage)?;

        let mut macros = vec![get_c_string(&mut buf)?];
        while let Ok(s) = get_c_string(&mut buf) {
            macros.push(s);
        }

        // TODO note: libmilter uses the "macros" *as is* -- this is bad:
        // down the line it is assumed that they always come in pairs: name -- value
        // so if we have an extra string at the end, drop it and log a warning
        // but also: name and value may be invalid/empty string value

        Ok(Self { stage, macros })
    }
}

/// A `Helo` command payload.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct HeloPayload {
    /// The hostname: the first NUL-terminated string of the payload.
    pub hostname: CString,
}

impl HeloPayload {
    /// Parses the payload of a `Helo` command.
    ///
    /// The buffer as a whole must end with a NUL byte, otherwise
    /// [`CommandError::NotNulTerminated`] is returned. Only the first string is
    /// used; any further NUL-terminated data after it is ignored.
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, CommandError> {
        ensure_nul_terminated(&buf)?;

        get_c_string(&mut buf).map(|hostname| Self { hostname })
    }
}

/*
pub trait ParseBuf<T> {
    fn parse_buf(self) -> Result<T, CommandError>;
}
impl ParseBuf<HeloPayload> for Bytes {
    fn parse_buf(mut self) -> Result<HeloPayload, CommandError> {
        ensure_nul_terminated(&self)?;

        let hostname = get_c_string(&mut self)?;

        Ok(HeloPayload { hostname })
    }
}
*/

/// A `Header` command payload.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct HeaderPayload {
    /// The header name: the first NUL-terminated string of the payload.
    pub name: CString,  // non-empty
    /// The header value: the second NUL-terminated string; may be empty.
    pub value: CString,
}

impl HeaderPayload {
    /// Parses the payload of a `Header` command: a name and a value, each
    /// NUL-terminated.
    ///
    /// Fails with [`CommandError::NotNulTerminated`] if the buffer does not end
    /// in NUL, with [`CommandError::EmptyCString`] if the name is empty, and
    /// with [`CommandError::NoCStringFound`] if the value is missing.
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, CommandError> {
        ensure_nul_terminated(&buf)?;

        // TODO slightly different from libmilter algo
        let name = get_c_string(&mut buf)?;

        if name.as_bytes().is_empty() {
            Err(CommandError::EmptyCString)
        } else {
            let value = get_c_string(&mut buf)?;
            Ok(Self { name, value })
        }
    }
}

/// An envelope address payload.
///
/// Used for both the `Mail` and the `Rcpt` command.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct EnvAddrPayload {
    /// The NUL-terminated strings of the payload, in buffer order.
    pub args: Vec<CString>,  // non-empty
}

impl EnvAddrPayload {
    /// Parses an envelope address payload: one or more NUL-terminated strings.
    ///
    /// Fails with [`CommandError::NoCStringFound`] if the buffer does not hold
    /// at least one string. Trailing bytes without a terminating NUL are ignored.
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, CommandError> {
        // The first argument is mandatory.
        let first = get_c_string(&mut buf)?;

        // Keep pulling strings until the reader reports that none is left.
        let rest = std::iter::from_fn(|| get_c_string(&mut buf).ok());

        let args = std::iter::once(first).chain(rest).collect();

        Ok(Self { args })
    }
}

bitflags! {
    /// Bit set stored in the `actions` field of an [`OptNegPayload`].
    ///
    /// Bits not listed here are dropped when the payload is parsed.
    // NOTE(review): the flag names look like libmilter's `SMFIF_*` constants;
    // confirm the values against `mfapi.h`.
    #[derive(Default)]
    pub struct Actions: u32 {
        const ADDHDRS = 0x1;
        const CHGBODY = 0x2;
        const ADDRCPT = 0x4;
        const DELRCPT = 0x8;
        const CHGHDRS = 0x10;
        const QUARANTINE = 0x20;
        const CHGFROM = 0x40;
        const ADDRCPT_PAR = 0x80;
        const SETSYMLIST = 0x100;
    }
}

bitflags! {
    /// Bit set stored in the `popts` field of an [`OptNegPayload`].
    ///
    /// Bits not listed here are dropped when the payload is parsed.
    // NOTE(review): the flag names look like libmilter's `SMFIP_*` constants;
    // confirm the values against `mfdef.h`.
    #[derive(Default)]
    pub struct ProtoOpts: u32 {
        const NOCONNECT = 0x1;
        const NOHELO = 0x2;
        const NOMAIL = 0x4;
        const NORCPT = 0x8;
        const NOBODY = 0x10;
        const NOHDRS = 0x20;
        const NOEOH = 0x40;
        const NR_HDR = 0x80;
        const NOUNKNOWN = 0x100;
        const NODATA = 0x200;
        const SKIP = 0x400;
        const RCPT_REJ = 0x800;
        const NR_CONN = 0x1000;
        const NR_HELO = 0x2000;
        const NR_MAIL = 0x4000;
        const NR_RCPT = 0x8000;
        const NR_DATA = 0x10000;
        const NR_UNKN = 0x20000;
        const NR_EOH = 0x40000;
        const NR_BODY = 0x80000;
        const HDR_LEADSPC = 0x100000;
    }
}

/// An `OptNeg` command payload.
///
/// The three fields are read from the payload in the order declared here,
/// each as a 32-bit integer.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct OptNegPayload {
    /// The protocol version; parsing rejects versions below 2.
    pub version: u32,
    /// The action flags; unknown bits are dropped.
    pub actions: Actions,
    /// The protocol option flags; unknown bits are dropped.
    pub popts: ProtoOpts,
}

impl OptNegPayload {
    /// Parses the payload of an `OptNeg` command: version, actions and
    /// protocol options, each a 32-bit integer.
    ///
    /// Fails with [`CommandError::MissingOptneg`] if fewer than twelve bytes
    /// are available, and with [`CommandError::UnsupportedProtocolVersion`] if
    /// the version is below 2. Unknown flag bits are silently dropped, and any
    /// bytes after the three fields are ignored.
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, CommandError> {
        // Three `u32` fields: version, actions, protocol options.
        const FIELDS_LEN: usize = 3 * std::mem::size_of::<u32>();

        if buf.remaining() < FIELDS_LEN {
            return Err(CommandError::MissingOptneg);
        }

        // The length check above guarantees that none of these reads can panic.
        let (version, action_bits, popt_bits) = (buf.get_u32(), buf.get_u32(), buf.get_u32());

        // TODO does this belong here?:
        if version < 2 {
            return Err(CommandError::UnsupportedProtocolVersion);
        }

        Ok(Self {
            version,
            actions: Actions::from_bits_truncate(action_bits),
            popts: ProtoOpts::from_bits_truncate(popt_bits),
        })
    }
}

/// An `Unknown` command payload.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct UnknownPayload {
    /// The argument: the first NUL-terminated string of the payload.
    pub arg: CString,
}

impl UnknownPayload {
    /// Parses the payload of an `Unknown` command: a single NUL-terminated
    /// string. Anything after that string is ignored.
    ///
    /// Fails with [`CommandError::NoCStringFound`] if the buffer contains no NUL.
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, CommandError> {
        get_c_string(&mut buf).map(|arg| Self { arg })
    }
}

/// Checks that `bytes` ends with a NUL byte.
///
/// An empty slice does not end with NUL and is therefore rejected with
/// [`CommandError::NotNulTerminated`] as well.
fn ensure_nul_terminated(bytes: &[u8]) -> Result<(), CommandError> {
    match bytes.last() {
        Some(&0) => Ok(()),
        _ => Err(CommandError::NotNulTerminated),
    }
}

/// Takes one byte off the front of `buf`.
///
/// Returns [`CommandError::NoU8Found`] if the buffer is empty; the buffer is
/// left untouched in that case.
fn get_u8(buf: &mut Bytes) -> Result<u8, CommandError> {
    if buf.has_remaining() {
        Ok(buf.get_u8())
    } else {
        Err(CommandError::NoU8Found)
    }
}

/// Takes a 16-bit integer off the front of `buf`.
///
/// Returns [`CommandError::NoU16Found`] if fewer than two bytes remain; the
/// buffer is left untouched in that case.
fn get_u16(buf: &mut Bytes) -> Result<u16, CommandError> {
    match buf.remaining() {
        0 | 1 => Err(CommandError::NoU16Found),
        _ => Ok(buf.get_u16()),
    }
}

/// Takes a NUL-terminated string off the front of `buf`.
///
/// Everything up to and including the first NUL byte is removed from `buf`.
/// If `buf` contains no NUL byte, it is left untouched and
/// [`CommandError::NoCStringFound`] is returned.
fn get_c_string(buf: &mut Bytes) -> Result<CString, CommandError> {
    let nul_index = buf
        .iter()
        .position(|&byte| byte == 0)
        .ok_or(CommandError::NoCStringFound)?;

    let raw = buf.split_to(nul_index + 1);

    // `raw` ends at the first NUL found above, so it holds exactly one NUL, in
    // final position, and this conversion cannot fail.
    let c_str = CStr::from_bytes_with_nul(raw.as_ref()).unwrap();

    Ok(c_str.to_owned())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `CString` expected in a parsed payload.
    fn c_string(text: &str) -> CString {
        CString::new(text).unwrap()
    }

    /// Runs the header parser over a static byte string.
    fn parse_header(raw: &'static [u8]) -> Result<HeaderPayload, CommandError> {
        HeaderPayload::parse_buffer(Bytes::from_static(raw))
    }

    /// Runs the helo parser over a static byte string.
    fn parse_helo(raw: &'static [u8]) -> Result<HeloPayload, CommandError> {
        HeloPayload::parse_buffer(Bytes::from_static(raw))
    }

    #[test]
    fn header_works() {
        let expected = HeaderPayload {
            name: c_string("name"),
            value: c_string("value"),
        };
        assert_eq!(parse_header(b"name\0value\0"), Ok(expected));

        // Empty and unterminated buffers are rejected.
        assert!(HeaderPayload::parse_buffer(Bytes::new()).is_err());
        assert!(parse_header(b"name").is_err());
    }

    #[test]
    fn helo_works() {
        let hello = || HeloPayload {
            hostname: c_string("hello"),
        };

        assert_eq!(parse_helo(b"hello\0"), Ok(hello()));

        // Empty and unterminated buffers are rejected.
        assert!(HeloPayload::parse_buffer(Bytes::new()).is_err());
        assert!(parse_helo(b"hello").is_err());

        // undocumented:
        // excess data is rejected when unterminated, ignored when terminated.
        assert!(parse_helo(b"hello\0excess").is_err());
        assert_eq!(parse_helo(b"hello\0excess\0"), Ok(hello()));
    }
}
