use crate::job::{ArcJob, StackJob};
use crate::latch::{CountLatch, LatchRef};
use crate::registry::{Registry, WorkerThread};
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;

mod test;









/// Runs `op` across the threads of the current thread pool and returns the
/// collected results.
///
/// The pool is the one reported by `Registry::current()` at the time of the
/// call. One job is submitted per pool thread, and the call does not hand
/// back its `Vec` until the jobs have been waited for, so the vector holds
/// one entry per job.
///
/// For more information, see the [`ThreadPool::broadcast()`][m] method.
///
/// [m]: struct.ThreadPool.html#method.broadcast
pub fn broadcast<OP, R>(op: OP) -> Vec<R>
where
    OP: Fn(BroadcastContext<'_>) -> R + Sync,
    R: Send,
{
    let registry = Registry::current();

    // SAFETY: NOTE(review) — the contract of `broadcast_in` is not written
    // down next to it; presumably it requires a registry that has not yet
    // terminated, which the freshly obtained current registry is assumed to
    // satisfy. Confirm against `Registry`.
    unsafe { broadcast_in(op, &registry) }
}








/// Queues `op` to run on the threads of the current thread pool without
/// waiting for it to finish.
///
/// The pool is the one reported by `Registry::current()` at the time of the
/// call. Because this function returns as soon as the jobs have been
/// submitted, `op` may still be running (or not yet started) afterwards;
/// that is why it must be `Send + Sync + 'static` and cannot borrow from the
/// caller's stack.
///
/// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method.
///
/// [m]: struct.ThreadPool.html#method.spawn_broadcast
pub fn spawn_broadcast<OP>(op: OP)
where
    OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
{
    let registry = Registry::current();

    // SAFETY: NOTE(review) — the contract of `spawn_broadcast_in` is not
    // written down next to it; presumably it requires a registry that has
    // not yet terminated, which the freshly obtained current registry is
    // assumed to satisfy. Confirm against `Registry`.
    unsafe { spawn_broadcast_in(op, &registry) }
}


/// Context handed to the closure run by `broadcast` and `spawn_broadcast`.
///
/// It refers to the worker thread on which the closure is being invoked and
/// exposes that worker's index and the size of its pool.
pub struct BroadcastContext<'a> {
    /// The worker thread that is executing the current invocation.
    worker: &'a WorkerThread,

    /// `dyn Fn()` is neither `Send` nor `Sync`, so carrying this marker keeps
    /// the context from picking up those auto traits.
    _marker: PhantomData<&'a mut dyn Fn()>,
}

impl<'a> BroadcastContext<'a> {
    /// Builds a context for the calling thread and passes it to `f`,
    /// returning whatever `f` returns.
    ///
    /// # Panics
    ///
    /// Panics if `WorkerThread::current()` is null for the calling thread.
    pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
        let worker_thread = WorkerThread::current();
        assert!(!worker_thread.is_null());

        // SAFETY: the pointer was checked to be non-null just above.
        // NOTE(review): this also relies on the worker staying alive for the
        // duration of `f` — confirm against `WorkerThread`.
        let worker = unsafe { &*worker_thread };

        let context = BroadcastContext {
            worker,
            _marker: PhantomData,
        };
        f(context)
    }

    /// Index of the worker thread running this invocation, as reported by
    /// the worker itself.
    #[inline]
    pub fn index(&self) -> usize {
        let worker = self.worker;
        worker.index()
    }

    /// Number of threads in the pool this worker belongs to.
    ///
    /// # Future compatibility note
    ///
    /// The count is read from the worker's registry on every call rather
    /// than being stored in the context, so prefer calling this method over
    /// assuming a fixed pool size.
    #[inline]
    pub fn num_threads(&self) -> usize {
        let registry = self.worker.registry();
        registry.num_threads()
    }
}

impl fmt::Debug for BroadcastContext<'_> {
    /// Formats the context as a `BroadcastContext` struct showing the worker
    /// index, the pool's thread count, and the registry's id.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("BroadcastContext");
        out.field("index", &self.index());
        out.field("num_threads", &self.num_threads());
        out.field("pool_id", &self.worker.registry().id());
        out.finish()
    }
}







/// Runs `op` through one job per thread of `registry` and returns the
/// results once the jobs have been waited for.
///
/// The jobs are `StackJob`s owned by this call, all tied to a single
/// `CountLatch` created with a count of `registry.num_threads()`. They are
/// handed to `registry.inject_broadcast`, after which this function waits on
/// the latch and then extracts each job's result in order.
///
/// # Safety
///
/// NOTE(review): the contract is not stated alongside this function.
/// Presumably `registry` must not have terminated yet — confirm against
/// `Registry`.
pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R>
where
    OP: Fn(BroadcastContext<'_>) -> R + Sync,
    R: Send,
{
    // Every job shares this one closure by reference; `op` itself is only
    // ever borrowed when a job runs.
    let run_op = move |injected: bool| {
        debug_assert!(injected);
        BroadcastContext::with(&op)
    };

    let n_threads = registry.num_threads();
    let current_thread = WorkerThread::current().as_ref();
    let latch = CountLatch::with_count(n_threads, current_thread);

    // Finish building the whole vector before any job ref is created, so the
    // jobs are not moved by a reallocation afterwards.
    let mut jobs = Vec::with_capacity(n_threads);
    for _ in 0..n_threads {
        jobs.push(StackJob::new(&run_op, LatchRef::new(&latch)));
    }

    registry.inject_broadcast(jobs.iter().map(|job| job.as_job_ref()));

    // The injected refs presumably point into `jobs`, so the jobs must stay
    // put until the latch reports completion. Extracting the results
    // afterwards may propagate a panic from `op` — TODO confirm against
    // `StackJob::into_result`.
    latch.wait(current_thread);
    jobs.into_iter().map(|job| job.into_result()).collect()
}







/// Submits `op` to `registry` as one shared `ArcJob`, referenced once per
/// thread, and returns right after the job refs have been injected.
///
/// Each execution of the job runs `op` under `registry.catch_unwind` and
/// then calls `registry.terminate()`, matching the
/// `increment_terminate_count()` performed when the corresponding job ref
/// was produced.
///
/// # Safety
///
/// NOTE(review): the contract is not stated alongside this function.
/// Presumably `registry` must not have terminated yet — confirm against
/// `Registry`.
pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
where
    OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
{
    // The job owns its own handle to the registry, so it does not depend on
    // the caller's borrow.
    let job_registry = Arc::clone(registry);
    let job = ArcJob::new(move || {
        job_registry.catch_unwind(|| BroadcastContext::with(&op));
        // (*) Balances the increment made when this job ref was created.
        job_registry.terminate();
    });

    let n_threads = registry.num_threads();
    let job_refs = (0..n_threads).map(|_| {
        // One increment per job ref, performed lazily as the ref is handed
        // out; it is undone at (*) once that job has run.
        registry.increment_terminate_count();

        ArcJob::as_static_job_ref(&job)
    });

    registry.inject_broadcast(job_refs);
}
