use std::marker::PhantomData;
use std::thread;

use crate::entities::{CtxWrapper, MapReduceError, ThreadError, Worker};
use crate::traits::{SCFunMF, SCFunRF, Sss};

/// Context used by [`manager`] when the caller does not provide one.
const DEFAULT_CONTEXT: () = ();

/// Boxed, type-erased reduce function, called as `(context, accumulator, next) -> accumulator`.
type BoxRF<Ctx, Res> = Box<dyn Fn(&Ctx, Res, Res) -> Res>;

/// Entry stage of the map/reduce builder: holds only the borrowed context that
/// is later handed to every map and reduce call.
pub struct Manager<'a, Ctx> {
    // Shared (read-only) context passed as the first argument to map/reduce.
    context: &'a Ctx,
}

/// Builder stage that carries the map function.
///
/// `MF` is invoked with the context, the chunk index and a reference to the
/// chunk, and produces one `Resp` per chunk.
pub struct MapManager<'a, Ctx, Req, Resp, MF>
where
    Resp: Sss,
    MF: Fn(&Ctx, usize, &Req) -> Resp,
{
    // `Req` is not stored in any field; this marker keeps the parameter used.
    phantom_req: PhantomData<Req>,
    // Previous stage, carrying the context.
    manager: Manager<'a, Ctx>,
    // The map function.
    fun: MF,
}

/// Builder stage that carries both the map function and the reduce function.
///
/// `RF` folds two `Resp` values (accumulator, next result) into one.
pub struct ReduceManager<'a, Ctx, Req, Resp, MF, RF>
where
    Resp: Sss,
    MF: Fn(&Ctx, usize, &Req) -> Resp,
    RF: Fn(&Ctx, Resp, Resp) -> Resp,
{
    // Previous stage: context and map function.
    map_manager: MapManager<'a, Ctx, Req, Resp, MF>,
    // The reduce function.
    fun: RF,
}

/// Reduce stage returned by `MapManager::reduce_result`: the map function yields
/// `Result<ResOk, ResErr>`, and the reduce function is stored boxed as a `BoxRF`
/// because it is a wrapper closure whose concrete type cannot be named.
type ReduceRManager<'a, Ctx, Req, ResOk, ResErr, MF> =
    ReduceManager<'a, Ctx, Req, Result<ResOk, ResErr>, MF, BoxRF<Ctx, Result<ResOk, ResErr>>>;

/// Final builder stage: the reduce stage plus the initial accumulator value.
pub struct FullManager<'a, Ctx, Req, Resp, MF, RF>
where
    Resp: Sss,
    MF: Fn(&Ctx, usize, &Req) -> Resp,
    RF: Fn(&Ctx, Resp, Resp) -> Resp,
{
    // Context, map function and reduce function.
    reduce_manager: ReduceManager<'a, Ctx, Req, Resp, MF, RF>,
    // Initial accumulator; the first worker's result is folded into it.
    default_value: Resp,
}

/// Starts a new map/reduce pipeline whose context is the unit value `()`.
///
/// Call [`Manager::context`] to supply a different context before `map`.
pub fn manager() -> Manager<'static, ()> {
    let context: &'static () = &DEFAULT_CONTEXT;
    Manager { context }
}

impl<'a, Ctx> Manager<'a, Ctx> {
    /// Replaces the context handed to every map and reduce call.
    ///
    /// The current context is discarded; the returned manager borrows
    /// `context` instead, so it cannot outlive it.
    pub fn context<NewCtx>(self, context: &NewCtx) -> Manager<'_, NewCtx> {
        Manager { context }
    }

    /// Sets the map function and moves to the map stage.
    ///
    /// `fun` is called once per chunk as `fun(context, chunk_index, chunk)`.
    pub fn map<Req, Resp: Sss, MF: Fn(&Ctx, usize, &Req) -> Resp>(
        self,
        fun: MF,
    ) -> MapManager<'a, Ctx, Req, Resp, MF> {
        MapManager {
            manager: self,
            fun,
            // `PhantomData` is a unit struct; calling `::default()` on it is
            // redundant (clippy::default_constructed_unit_structs).
            phantom_req: PhantomData,
        }
    }
}

impl<'a, Ctx, Req, Resp, MF> MapManager<'a, Ctx, Req, Resp, MF>
where
    Resp: Sss,
    MF: SCFunMF<Ctx, Req, Resp>,
{
    /// Sets the reduce function and moves to the reduce stage.
    ///
    /// `fun` receives the context, the accumulator and the next map result, and
    /// returns the new accumulator.
    pub fn reduce<RF>(self, fun: RF) -> ReduceManager<'a, Ctx, Req, Resp, MF, RF>
    where
        RF: SCFunRF<Ctx, Resp>,
    {
        let map_manager = self;
        ReduceManager { fun, map_manager }
    }
}

/// Extracts the `Ok` payload of `res`.
///
/// # Panics
///
/// Panics with `"err in reduce function"` when `res` is an `Err`.
#[inline]
fn get_ok<O, E>(res: Result<O, E>) -> O {
    res.unwrap_or_else(|_| panic!("err in reduce function"))
}

impl<'a, Ctx, Req, ROk, RErr, MF> MapManager<'a, Ctx, Req, Result<ROk, RErr>, MF>
where
    ROk: Sss,
    RErr: Sss,
    MF: SCFunMF<Ctx, Req, Result<ROk, RErr>>,
{
    /// Sets a fallible reduce function over the `Ok` payloads of the map results.
    ///
    /// `fun` is wrapped so that it receives the unwrapped `ROk` values; the
    /// wrapper panics (via `get_ok`) if it is ever handed an `Err`.
    pub fn reduce_result<RF>(self, fun: RF) -> ReduceRManager<'a, Ctx, Req, ROk, RErr, MF>
    where
        RF: 'static + Fn(&Ctx, ROk, ROk) -> Result<ROk, RErr>,
    {
        let wrapped: BoxRF<Ctx, Result<ROk, RErr>> = Box::new(move |ctx, left, right| {
            let left = get_ok(left);
            let right = get_ok(right);
            fun(ctx, left, right)
        });
        ReduceManager {
            map_manager: self,
            fun: wrapped,
        }
    }
}

impl<'a, Ctx, Req, Resp, MF, RF> ReduceManager<'a, Ctx, Req, Resp, MF, RF>
where
    Resp: Sss,
    MF: SCFunMF<Ctx, Req, Resp>,
    RF: 'static + Fn(&Ctx, Resp, Resp) -> Resp,
{
    /// Sets the initial accumulator of the reduction and moves to the final stage.
    pub fn default(self, default_value: Resp) -> FullManager<'a, Ctx, Req, Resp, MF, RF> {
        let reduce_manager = self;
        FullManager {
            default_value,
            reduce_manager,
        }
    }
}

impl<'a, Ctx, Req, Resp, MF, RF> ReduceManager<'a, Ctx, Req, Resp, MF, RF>
where
    Resp: Sss + Default,
    MF: SCFunMF<Ctx, Req, Resp>,
    RF: 'static + Fn(&Ctx, Resp, Resp) -> Resp,
{
    /// Completes the pipeline with `Resp::default()` as the initial accumulator.
    fn into_full(self) -> FullManager<'a, Ctx, Req, Resp, MF, RF> {
        let default_value = Resp::default();
        FullManager {
            reduce_manager: self,
            default_value,
        }
    }

    /// Runs the pipeline, starting the reduction from `Resp::default()`.
    ///
    /// # Errors
    ///
    /// Returns the error of a worker thread whose join failed
    /// (delegates to `FullManager::run`).
    pub fn run(self, chunks: &[Req]) -> Result<Resp, ThreadError> {
        let full = self.into_full();
        full.run(chunks)
    }
}

impl<'a, Ctx, Req, ROk, RErr, MF, RF> ReduceManager<'a, Ctx, Req, Result<ROk, RErr>, MF, RF>
where
    ROk: Sss + Default,
    RErr: Sss,
    MF: SCFunMF<Ctx, Req, Result<ROk, RErr>>,
    RF: 'static + Fn(&Ctx, Result<ROk, RErr>, Result<ROk, RErr>) -> Result<ROk, RErr>,
{
    /// Runs a `Result`-producing pipeline, starting the reduction from
    /// `Ok(ROk::default())` (delegates to `FullManager::run_result`).
    pub fn run_result(self, chunks: &[Req]) -> Result<ROk, MapReduceError<RErr>> {
        let seed: Result<ROk, RErr> = Ok(ROk::default());
        let full = self.default(seed);
        full.run_result(chunks)
    }
}

impl<
        'a,
        Ctx,
        Req,
        Resp: Sss,
        MF: SCFunMF<Ctx, Req, Resp>,
        RF: 'static + Fn(&Ctx, Resp, Resp) -> Resp,
    > FullManager<'a, Ctx, Req, Resp, MF, RF>
{
    /// Spawns one worker thread per element of `chunks`; worker `id` runs the
    /// map function on `chunks[id]` with the shared context.
    ///
    /// The threads receive `CtxWrapper`s around *borrowed* data (the map
    /// function, the context and the chunk). NOTE(review): this presumes
    /// `CtxWrapper` erases the borrow's lifetime to satisfy `spawn`'s `'static`
    /// bound. Every returned worker must therefore be joined before those
    /// borrows end.
    ///
    /// # Panics
    ///
    /// Panics if the OS refuses to create a thread. Workers spawned before the
    /// failure are joined first, so none is left running detached.
    fn make_workers(&self, chunks: &[Req]) -> Vec<Worker<Resp>> {
        let fun = CtxWrapper::new(&self.reduce_manager.map_manager.fun);
        let ctx = CtxWrapper::new(self.reduce_manager.map_manager.manager.context);
        let mut workers = Vec::with_capacity(chunks.len());
        for (id, chunk) in chunks.iter().enumerate() {
            let request = CtxWrapper::new(chunk);
            // `Builder::spawn` reports failure instead of panicking, so the
            // already-running workers can be joined before we bail out.
            let spawned = thread::Builder::new()
                .spawn(move || fun.get::<MF>()(ctx.get::<Ctx>(), id, request.get::<Req>()));
            let handler: thread::JoinHandle<Resp> = match spawned {
                Ok(handle) => handle,
                Err(err) => {
                    for worker in workers {
                        let _ = worker.thread.join();
                    }
                    panic!("failed to spawn map worker {}: {}", id, err);
                }
            };
            workers.push(Worker {
                id,
                thread: Box::new(handler),
            });
        }
        workers
    }

    /// Runs the map function on every chunk in parallel, then folds the results
    /// in chunk order starting from `default_value`:
    /// `reduce(.. reduce(reduce(default, r0), r1) .., rN)`.
    ///
    /// # Errors
    ///
    /// If joining a worker fails, reduction stops, the remaining workers are
    /// still joined, and that join error is returned.
    ///
    /// # Panics
    ///
    /// A panic raised by the reduce function is re-raised unchanged, but only
    /// after every outstanding worker has been joined. Otherwise the workers
    /// would keep reading the borrowed chunks/context while the caller unwinds.
    pub fn run(self, chunks: &[Req]) -> Result<Resp, ThreadError> {
        let mut pending = self.make_workers(chunks).into_iter();
        let ctx = self.reduce_manager.map_manager.manager.context;
        let reduce = &self.reduce_manager.fun;
        let mut result = self.default_value;
        while let Some(worker) = pending.next() {
            let data = match worker.thread.join() {
                Ok(val) => val,
                Err(err) => {
                    // Skip the remaining results, but still wait for the threads.
                    for rest in pending {
                        let _ = rest.thread.join();
                    }
                    return Err(err);
                }
            };
            let step = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                reduce(ctx, result, data)
            }));
            result = match step {
                Ok(val) => val,
                Err(payload) => {
                    // Join before unwinding further so no worker outlives the
                    // data it borrows.
                    for rest in pending {
                        let _ = rest.thread.join();
                    }
                    std::panic::resume_unwind(payload)
                }
            };
        }
        Ok(result)
    }
}

impl<
        'a,
        Ctx,
        Req,
        ROk: Sss,
        RErr: Sss,
        MF: 'static + Fn(&Ctx, usize, &Req) -> Result<ROk, RErr>,
        RF: 'static + Fn(&Ctx, Result<ROk, RErr>, Result<ROk, RErr>) -> Result<ROk, RErr>,
    > FullManager<'a, Ctx, Req, Result<ROk, RErr>, MF, RF>
{
    /// Runs a `Result`-producing pipeline. Map results are folded in chunk
    /// order, starting from `default_value`, for as long as the accumulator is
    /// `Ok`.
    ///
    /// The first `Err` stops the reduction. It may come from a map step, from
    /// the reduce function, or from an `Err` default value. As a result, the
    /// reduce function is only ever called with an `Ok` accumulator and an
    /// `Ok` map result.
    ///
    /// # Errors
    ///
    /// - `MapReduceError::ThreadFailed` if joining a worker failed;
    /// - `MapReduceError::Custom` with the first `Err` encountered.
    ///
    /// All workers are joined before returning in every case.
    ///
    /// # Panics
    ///
    /// A panic from the reduce function is re-raised after all remaining
    /// workers have been joined.
    pub fn run_result(self, chunks: &[Req]) -> Result<ROk, MapReduceError<RErr>> {
        let mut pending = self.make_workers(chunks).into_iter();
        let ctx = self.reduce_manager.map_manager.manager.context;
        let reduce = &self.reduce_manager.fun;
        let mut result = self.default_value;
        let mut failed = None;
        // Checking the accumulator up front also covers an `Err` default, which
        // would otherwise be passed to the reduce function (and make the
        // `reduce_result` wrapper panic).
        while result.is_ok() {
            let worker = match pending.next() {
                Some(worker) => worker,
                None => break,
            };
            let worker_res = match worker.thread.join() {
                Ok(val) => val,
                Err(err) => {
                    failed = Some(err);
                    break;
                }
            };
            result = if worker_res.is_ok() {
                let step = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                    reduce(ctx, result, worker_res)
                }));
                match step {
                    Ok(val) => val,
                    Err(payload) => {
                        // Join before unwinding further so no worker outlives
                        // the data it borrows.
                        for rest in pending {
                            let _ = rest.thread.join();
                        }
                        std::panic::resume_unwind(payload)
                    }
                }
            } else {
                // A failed map step stops the reduction; its error is the result.
                worker_res
            };
        }
        // Skip the remaining results, but still wait for the threads.
        for worker in pending {
            let _ = worker.thread.join();
        }
        match (failed, result) {
            (Some(err), _) => Err(MapReduceError::ThreadFailed(err)),
            (_, Err(err)) => Err(MapReduceError::Custom(err)),
            (_, Ok(val)) => Ok(val),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;
    use std::time::Duration;

    #[test]
    fn test_manager() {
        let res = manager()
            .map(|_, _, n: &usize| n.to_string())
            .reduce(|_, a: String, b: String| format!("{}{}", a, b))
            .run(&[1, 2, 3])
            .unwrap();
        assert_eq!(res, "123");
    }

    #[test]
    fn test_with_context() {
        let var: usize = 5;
        let res = manager()
            .context(&var)
            .map(|ctx: &usize, _, n: &usize| (n + ctx).to_string())
            .reduce(|_, a: String, b: String| format!("{}{}", a, b))
            .run(&[1, 2, 3])
            .unwrap();
        assert_eq!(res, "678");
    }

    #[test]
    fn test_complex_types() {
        let arg = Box::new("abc");
        let args: Vec<Box<dyn Fn(&str) -> String>> = vec![
            Box::new(|s: &str| format!("{}a", s)),
            Box::new(|s: &str| format!("{}b", s)),
            Box::new(|s: &str| format!("{}c", s)),
        ];
        let res = manager()
            .context(&*arg)
            .map(|ctx, _, f: &Box<dyn Fn(&str) -> String>| f(ctx))
            .reduce(|_, a: String, b: String| format!("({}{})", a, b))
            .run(&args)
            .unwrap();
        assert_eq!(res, "(((abca)abcb)abcc)");
    }

    #[test]
    fn test_slices() {
        let params: &[&[usize]] = &[&[1, 2, 3], &[4, 5]];
        let res = manager()
            .map(|_, _, s: &&[usize]| s.len())
            .reduce(|_, a, b| a + b)
            .run(params)
            .unwrap();
        assert_eq!(res, 5);
    }

    #[test]
    fn test_result_ok() {
        let res = manager()
            .map(|_, _, s: &&str| u8::from_str(s))
            .reduce_result(|_, a: u8, b: u8| Ok(a + b))
            .run_result(&["1", "2", "3"]);
        match res {
            Ok(6) => (),
            _ => panic!(),
        }
    }

    #[test]
    fn test_result_err_map() {
        let res = manager()
            .map(|_, _, s: &&str| u8::from_str(s))
            .reduce_result(|_, a: u8, b: u8| Ok(a + b))
            .run_result(&["1", "arr", "3"]);
        let err = match res {
            Err(MapReduceError::Custom(err)) => err.to_string(),
            _ => panic!(),
        };
        assert_eq!(err, "invalid digit found in string");
    }

    #[test]
    fn test_result_err_reduce() {
        let res = manager()
            .map(|_, _, s: &u8| Ok(*s))
            .reduce_result(|_, _: u8, _: u8| Err(()))
            .run_result(&[1, 2, 3]);
        match res {
            Err(MapReduceError::Custom(())) => (),
            _ => panic!(),
        }
    }

    /// Later chunks sleep longer, so the failing chunks are still running when
    /// the first error is observed; `run_result` must still report that error.
    /// The sleep unit is 100 ms (not 1 s) to keep the suite fast while
    /// preserving the same ordering.
    #[test]
    fn test_sleeps_and_errors() {
        let res = manager()
            .map(|_, _, s: &u8| {
                thread::sleep(Duration::from_millis(u64::from(*s) * 100));
                if *s <= 1 {
                    Ok(*s)
                } else {
                    Err(())
                }
            })
            .reduce_result(|_, a, b| Ok(a + b))
            .run_result(&[1, 1, 2, 3, 3]);
        match res {
            Err(MapReduceError::Custom(())) => (),
            _ => panic!(),
        }
    }

    #[test]
    fn test_thread_id() {
        let res: String = manager()
            .map(|_, thread, val: &char| format!("{}:{}", thread, val))
            .reduce(|_, a, b| format!("[{}{}]", a, b))
            .run(&['a', 'b', 'c'])
            .unwrap();
        assert_eq!(res, "[[[0:a]1:b]2:c]");
    }

    struct StructWrapper {
        data: usize,
    }

    /// Sums `data` in two chunks (`[0..2]` and `[2..4]`); expects at least four
    /// elements.
    fn func(context: &str, data: &[StructWrapper]) -> usize {
        let res = manager()
            .context(&context)
            .map(|_, _, val: &&[StructWrapper]| {
                let mut sum = 0;
                for s in *val {
                    sum += s.data
                }
                sum
            })
            .reduce(|_, a, b| a + b)
            .run(&[&data[0..2], &data[2..4]]);
        match res {
            Ok(val) => val,
            _ => panic!(),
        }
    }

    #[test]
    fn test_map_per_chunks() {
        let wrapped = vec![
            StructWrapper { data: 1 },
            StructWrapper { data: 2 },
            StructWrapper { data: 3 },
            StructWrapper { data: 4 },
        ];
        assert_eq!(func("hello", &wrapped), 10);
        assert_eq!(wrapped[0].data, 1);
    }
}
