use nix::ioctl_read;
use std::fs;
use std::os::unix::prelude::AsRawFd;
use std::path::Path;

use super::Vulnerability;
use crate::Error;
use crate::SwitchDevice;

// ioctl request number for submitting a URB through usbfs. The magic/number
// pair is intended to mirror the kernel's `USBDEVFS_SUBMITURB`
// (`_IOR('U', 10, struct usbdevfs_urb)` in <linux/usbdevice_fs.h>).
const IOCTL_MAGIC: u8 = b'U';
const IOCTL_TYPE: u8 = 10;

// Generates `unsafe fn usb_submit_urb(fd, *mut UsbUrb) -> nix::Result<_>`,
// which hands a `UsbUrb` to the kernel by pointer.
ioctl_read!(usb_submit_urb, IOCTL_MAGIC, IOCTL_TYPE, UsbUrb);

/// Userspace image of the kernel's `struct usbdevfs_urb`, passed by pointer
/// to the `usb_submit_urb` ioctl.
///
/// `#[repr(C)]` together with the exact field order and integer widths below
/// is what makes this layout-compatible with the kernel ABI; do not reorder,
/// retype, or add fields.
#[derive(Debug)]
#[repr(C)]
pub struct UsbUrb {
    /// Transfer type; `UsbUrb::new` always sets the control-transfer type.
    r#type: u8,
    /// Endpoint address the URB targets (0 for the control requests built here).
    endpoint: u8,
    status: i32,
    flags: u32,
    /// Start of the transfer buffer. For the control requests built in this file
    /// it points at the setup packet followed by the data stage.
    buffer: *mut u8,
    /// Length of `buffer` in bytes.
    buffer_length: i32,
    actual_length: i32,
    start_frame: i32,
    // NOTE(review): in the kernel header this slot is a union of
    // `number_of_packets` (isochronous) and `stream_id` (bulk streams);
    // only its 4-byte size matters for control transfers — confirm if reused.
    stream_id: u32,
    error_count: i32,
    signr: u32,
    usercontext: *mut u8,
}

impl UsbUrb {
    /// Builds a control-transfer URB for endpoint 0 whose buffer is `buf`.
    ///
    /// Only the type, endpoint, buffer pointer and buffer length are meaningful;
    /// every other field is zeroed / null. The result stores a raw pointer into
    /// `buf`, so `buf` must stay alive for as long as the URB may be used by any
    /// ioctl it is submitted through.
    fn new(buf: &mut [u8]) -> Self {
        const URB_CONTROL_REQUEST: u8 = 2;
        const ENDPOINT: u8 = 0;

        let buffer_length = buf.len() as i32;
        let buffer = buf.as_mut_ptr();

        Self {
            r#type: URB_CONTROL_REQUEST,
            endpoint: ENDPOINT,
            status: 0,
            flags: 0,
            buffer,
            buffer_length,
            actual_length: 0,
            start_frame: 0,
            stream_id: 0,
            error_count: 0,
            signr: 0,
            usercontext: std::ptr::null_mut(),
        }
    }
}

impl Vulnerability for SwitchDevice {
    fn backend_name() -> &'static str {
        "linux"
    }

    fn trigger(&self, length: usize) -> Result<(), Error> {
        const GET_STATUS: u8 = 0x0;
        const STANDARD_REQUEST_DEVICE_TO_HOST_TO_ENDPOINT: u8 = 0x82;

        if !self.validate_environment() {
            return Err(Error::LinuxEnvError);
        }

        let file_path = format!(
            "/dev/bus/usb/{:03}/{:03}",
            self.device().device().bus_number(),
            self.device().device().address()
        );
        let file = match fs::File::options().read(true).write(true).open(&file_path) {
            Ok(file) => file,
            Err(_) => return Err(Error::LinuxEnvError),
        };
        let fd = file.as_raw_fd();

        let mut setup_packet = Vec::with_capacity(length + 8);
        setup_packet.extend(STANDARD_REQUEST_DEVICE_TO_HOST_TO_ENDPOINT.to_le_bytes());
        setup_packet.extend(GET_STATUS.to_le_bytes());
        setup_packet.extend(0u16.to_le_bytes());
        setup_packet.extend(0u16.to_le_bytes());
        setup_packet.extend((length as u16).to_le_bytes());
        setup_packet.resize(setup_packet.len() + length, b'\0');

        let mut usb_urb = UsbUrb::new(&mut setup_packet);
        unsafe { usb_submit_urb(fd, &mut usb_urb as *mut UsbUrb).unwrap() };

        // we fake a timeout to remain consistant with our other enviornments
        Err(Error::UsbError(rusb::Error::Timeout))
    }

    fn supported(&self) -> bool {
        true
    }
}

impl SwitchDevice {
    /// Returns whether our device is attached to a USB host controller that
    /// accepts giant control requests.
    ///
    /// We can only inject giant control requests on devices that are backed
    /// by certain USB controllers -- typically, the xhci_hcd on most PCs.
    /// For each supported controller driver, this scans its `usb*` sysfs child
    /// nodes and succeeds if any of them reports our device's bus number.
    fn validate_environment(&self) -> bool {
        const SUPPORTED_USB_CONTROLLERS: [&str; 2] =
            ["pci/drivers/xhci_hcd", "platform/drivers/dwc_otg"];

        SUPPORTED_USB_CONTROLLERS.iter().any(|hci_name| {
            let pattern = format!("/sys/bus/{hci_name}/*/usb*");
            glob::glob(&pattern)
                .expect("controller glob pattern is built from constants and is valid")
                // An entry that cannot be read (e.g. a permission error on one
                // sysfs directory) just doesn't match; it must not abort the scan.
                .filter_map(Result::ok)
                .any(|path| self.node_matches_our_device(&path))
        })
    }

    /// Checks to see if the given sysfs node matches our given device.
    /// Can be used to check if an xhci_hcd controller subnode reflects a given device.
    ///
    /// Only the bus number is compared: the node matches when it has a readable
    /// `busnum` attribute equal to our device's bus number.
    fn node_matches_our_device(&self, path: &Path) -> bool {
        let bus_path = path.join("busnum");

        bus_path.exists()
            && read_num_file(&bus_path) == Some(self.device().device().bus_number())
    }
}

/// Parses the single unsigned number stored in a sysfs-style attribute file.
///
/// Leading/trailing whitespace (such as a trailing newline) is ignored, and
/// invalid UTF-8 is replaced lossily before parsing. Returns `None` if the file
/// cannot be read or its trimmed contents do not parse as a `u8`.
fn read_num_file(path: &Path) -> Option<u8> {
    let bytes = fs::read(path).ok()?;
    String::from_utf8_lossy(&bytes).trim().parse::<u8>().ok()
}
