# ThreadPool 在项目中的应用
# 配置
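项目中使用线程池时,最好自定义 `ThreadPoolExecutor`,而不是直接使用 `Executors.newXxx` 工厂方法。自定义线程池的可定制程度更高(核心/最大线程数、有界队列、线程命名、拒绝策略都可以显式指定),出现问题时更容易定位;例如 `Executors.newFixedThreadPool` 使用的是无界队列,任务堆积时可能导致内存溢出。同时,还可以在自定义线程池中把 traceId 等 MDC 上下文传递给子线程。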
/**
 * Thread pool for per-site sign-rate calculations.
 *
 * <p>Configuration:
 * <ul>
 *   <li>10 core and 50 maximum threads, with a 60 s keep-alive for non-core threads.</li>
 *   <li>A bounded {@code LinkedBlockingDeque} of {@code QUEUE_SIZE} entries.</li>
 *   <li>Threads named {@code thread-sign-rate-N}.</li>
 *   <li>A rejection handler that logs the pool state and throws {@code ThreadPoolException}
 *       on the submitting thread.</li>
 * </ul>
 *
 * <p>Only {@code execute} is overridden to propagate the submitter's MDC context (traceId).
 * Every {@code submit}/{@code invokeAll}/{@code invokeAny} variant inherited from
 * {@code AbstractExecutorService} wraps its task in a {@code FutureTask} and hands it to
 * {@code execute}. As a result, each task is wrapped exactly once. Failures inside a submitted
 * task are recorded by the {@code FutureTask} and surface from {@code Future.get()} as an
 * {@code ExecutionException}.
 */
@Bean(name = "siteSignRateThreadPoolExecutor")
public ThreadPoolExecutor siteSignRateThreadPoolExecutor() {
    ThreadFactory threadFactory = new ThreadFactory() {
        private final AtomicInteger threadNumber = new AtomicInteger(1);

        @Override
        public Thread newThread(Runnable r) {
            return new Thread(r, "thread-sign-rate-" + threadNumber.getAndIncrement());
        }
    };
    // Fail fast on the submitting thread instead of silently dropping or running the task inline.
    RejectedExecutionHandler rejectionHandler = (r, executor) -> {
        log.error("线程池队列已满,当前线程池信息:{}", executor.toString());
        throw new ThreadPoolException("当前系统繁忙,请稍后再试");
    };
    return new ThreadPoolExecutor(10,
            50,
            60,
            TimeUnit.SECONDS,
            new LinkedBlockingDeque<>(QUEUE_SIZE),
            threadFactory,
            rejectionHandler) {
        @Override
        public void execute(Runnable command) {
            // The MDC copy is taken on the submitting thread; wrap() installs it on the worker.
            super.execute(wrap(command, MDC.getCopyOfContextMap()));
        }
    };
}
重新包装任务:把提交任务时的 MDC 上下文(包括 traceId)设置到执行任务的子线程中,便于日志链路追踪。
/**
 * Wraps a {@link Callable} so that it runs with the given MDC context.
 *
 * <p>Before the call, the worker thread's MDC is replaced by {@code context}, or cleared when
 * {@code context} is null, and a traceId is generated if none is present. The MDC is cleared
 * afterwards.
 *
 * <p>Failures are logged and then propagated. A failed task therefore surfaces as an
 * {@code ExecutionException} from {@code Future.get()}, rather than looking like a successful
 * task that returned {@code null}.
 *
 * @param callable the task to run
 * @param context  MDC context captured on the submitting thread; may be null
 * @param <T>      result type
 * @return the wrapped task
 */
public static <T> Callable<T> wrap(final Callable<T> callable, final Map<String, String> context) {
    return () -> {
        if (context == null) {
            MDC.clear();
        } else {
            MDC.setContextMap(context);
        }
        setTraceIdIfAbsent();
        try {
            return callable.call();
        } catch (ThreadPoolException e) {
            log.error("线程池已满.", e);
            throw new ThreadPoolException("当前系统繁忙,请稍后再试");
        } catch (Exception e) {
            log.error("子线程计算错误.", e);
            // Rethrow so callers can tell a failure apart from a legitimately null result.
            throw e;
        } finally {
            MDC.clear();
        }
    };
}
/**
 * Wraps a {@link Runnable} so that it runs with the given MDC context.
 *
 * <p>The worker thread's MDC is set to {@code context}, or cleared when it is null, and a
 * traceId is added if absent. A {@code ThreadPoolException} is logged and rethrown as a fresh
 * {@code ThreadPoolException}. Any other exception is logged and not rethrown. The MDC is always
 * cleared when the task finishes.
 *
 * @param runnable the task to run
 * @param context  MDC context captured on the submitting thread; may be null
 * @return the wrapped task
 */
public static Runnable wrap(final Runnable runnable, final Map<String, String> context) {
    return new Runnable() {
        @Override
        public void run() {
            if (context != null) {
                MDC.setContextMap(context);
            } else {
                MDC.clear();
            }
            setTraceIdIfAbsent();
            try {
                runnable.run();
            } catch (ThreadPoolException poolFull) {
                log.error("线程池已满.", poolFull);
                throw new ThreadPoolException("当前系统繁忙,请稍后再试");
            } catch (Exception failure) {
                // Best-effort: the failure is logged here and not rethrown.
                log.error("子线程计算错误.", failure);
            } finally {
                MDC.clear();
            }
        }
    };
}
/**
 * Puts a random UUID into the MDC under {@code "traceId"} unless a value is already present.
 */
public static void setTraceIdIfAbsent() {
    String existingTraceId = MDC.get("traceId");
    if (existingTraceId != null) {
        return;
    }
    MDC.put("traceId", UUID.randomUUID().toString());
}
# 等待 submit 结果
如代码所示,先收集所有任务返回的 Future,再依次调用 `future.get()` 阻塞等待,即可获取最终结果。
/**
 * Computes the sign rate of every site in {@code reportSiteList} concurrently.
 *
 * <p>One {@code EachSiteSignRateTask} is submitted per site, and results are collected in
 * submission order.
 * <ul>
 *   <li>A site whose task fails, or yields {@code null}, is logged and skipped.</li>
 *   <li>If the calling thread is interrupted while waiting, the interrupt flag is restored, the
 *       tasks not yet collected are cancelled, and the results gathered so far are returned.</li>
 * </ul>
 *
 * @param reportSiteList sites to compute (one task each)
 * @param param          query parameters passed to every task
 * @return the per-site results that completed successfully
 */
@Override
public List<SiteSignRateVO> listSiteRate(List<Site> reportSiteList, SiteSignRateParam param) {
    // Pagination bounds reportSiteList, so the number of tasks (and DB round-trips) stays controlled.
    List<Future<SiteSignRateVO>> futureList = new ArrayList<>(reportSiteList.size());
    for (Site site : reportSiteList) {
        SiteSignRateVO vo = SiteSignRateVO.init(site.getSiteId(), site.getSiteName());
        futureList.add(siteSignRateThreadPoolExecutor.submit(new EachSiteSignRateTask(param, vo)));
    }
    List<SiteSignRateVO> voList = new ArrayList<>(futureList.size());
    for (int i = 0; i < futureList.size(); i++) {
        SiteSignRateVO vo;
        try {
            vo = futureList.get(i).get();
        } catch (InterruptedException e) {
            // Restore the interrupt status and stop waiting; do not leave the remaining work running.
            Thread.currentThread().interrupt();
            log.error("获取结果失败", e);
            for (Future<SiteSignRateVO> pending : futureList.subList(i, futureList.size())) {
                pending.cancel(true);
            }
            break;
        } catch (Exception e) {
            log.error("获取结果失败", e);
            continue;
        }
        // A task wrapper that swallows failures may hand back null; never put null into the result.
        if (vo != null) {
            voList.add(vo);
        }
    }
    return voList;
}
# CountDownLatch封装
除了上述通过 Future 阻塞等待的方法外,还可以使用 `CountDownLatch` 进行阻塞等待。一定要注意:latch 的计数必须与实际提交的任务数完全一致。例如列表为空时要及时返回;否则若仍把它当作 1 个任务,循环一次都不会执行,也就永远不会 countDown,调用线程会一直阻塞在 `get()` 上。反之,如果计数少于实际任务数,latch 会提前归零,拿到的结果就不完整。
// An empty list would submit no tasks at all; return early instead of waiting on the latch.
if (Collections.isEmpty(orderIdList)) {
    return vo;
}
// Query in batches of 500 ids. Per the original note: (1) an IN query accepts at most 1000 items,
// and (2) smaller batches run in parallel for speed.
int unitLength = 500;
// The latch count MUST equal the number of tasks the loop below submits, i.e.
// ceil(size / unitLength). Too high and the caller blocks until the timeout; too low and
// get() returns before every slice has reported.
int taskCount = (orderIdList.size() + unitLength - 1) / unitLength;
SynchroniseUtil<SiteSignRateVO> synchroniseUtil = new SynchroniseUtil<>(taskCount);
for (int i = 0; i < orderIdList.size(); i += unitLength) {
    int toIndex = Math.min(i + unitLength, orderIdList.size());
    List<String> subOrderIdList = orderIdList.subList(i, toIndex);
    // Core calculation for this slice; the task reports back through synchroniseUtil.
    SubSiteSignRateTask signRateTask = new SubSiteSignRateTask(i, toIndex, vo, param, subOrderIdList, synchroniseUtil);
    siteSplitSignRateThreadPoolExecutor.execute(signRateTask);
}
// Wait until every slice has reported, then collect all slice results.
List<SiteSignRateVO> subVoList;
try {
    subVoList = synchroniseUtil.get(3, TimeUnit.MINUTES);
} catch (Exception e) {
    if (e instanceof InterruptedException) {
        // Preserve the interrupt status for callers further up the stack.
        Thread.currentThread().interrupt();
    }
    log.error("查询已用3分钟,放弃此次查询", e);
    throw new ApiException("数据量过大,请考虑缩小查找范围");
}
/**
 * Collects results from a fixed number of asynchronous tasks and lets a caller block until all
 * of them have reported.
 *
 * <p>Construct it with the exact number of tasks. Each call to either {@code addResult}
 * overload counts the latch down by one, so each task should call one of them exactly once. A
 * task with nothing to contribute can call {@code addResult} with an empty list to release its
 * slot.
 *
 * <p>{@code get} returns a snapshot copy of the collected results. The copy is safe to iterate
 * and mutate without synchronizing on this object.
 *
 * @param <T> result element type
 */
@Slf4j
public class SynchroniseUtil<T> {
    private final CountDownLatch countDownLatch;
    private final List<T> result = Collections.synchronizedList(new ArrayList<T>());

    /**
     * @param count number of tasks expected to report; must match the number actually submitted
     */
    public SynchroniseUtil(int count) {
        this.countDownLatch = new CountDownLatch(count);
    }

    /**
     * Blocks until every task has reported.
     *
     * @return a copy of all collected results
     * @throws InterruptedException if the waiting thread is interrupted
     */
    public List<T> get() throws InterruptedException {
        countDownLatch.await();
        return snapshot();
    }

    /**
     * Blocks until every task has reported or the timeout elapses.
     *
     * @return a copy of all collected results
     * @throws RuntimeException if the timeout elapses first
     * @throws Exception        including {@link InterruptedException} if the waiting thread is interrupted
     */
    public List<T> get(long timeout, TimeUnit timeUnit) throws Exception {
        if (countDownLatch.await(timeout, timeUnit)) {
            return snapshot();
        }
        throw new RuntimeException("超时");
    }

    /** Records one result and counts the latch down once. */
    public void addResult(T resultMember) {
        result.add(resultMember);
        countDownLatch.countDown();
        log.debug("线程{}已处理1条数据,剩下任务数为:{}",
                Thread.currentThread().getName(),
                countDownLatch.getCount());
    }

    /** Records several results (possibly none) and counts the latch down once. */
    public void addResult(List<T> resultMembers) {
        result.addAll(resultMembers);
        countDownLatch.countDown();
        log.debug("线程{}已处理{}条数据,剩下任务数为:{}",
                Thread.currentThread().getName(),
                resultMembers.size(),
                countDownLatch.getCount());
    }

    /** Copies the synchronized list while holding its lock, as its iteration contract requires. */
    private List<T> snapshot() {
        synchronized (result) {
            return new ArrayList<>(result);
        }
    }
}
/**
 * Computes the sign-rate result for one slice of order ids and reports it to a shared
 * {@code SynchroniseUtil}.
 *
 * <p>Every run reports to {@code synchroniseUtil} exactly once, even when the calculation throws.
 * This keeps a failed slice from leaving the thread that waits on {@code synchroniseUtil.get(...)}
 * blocked until its timeout. A failed slice contributes no result (an empty list is reported),
 * and its exception still propagates to the executing thread.
 */
@Slf4j
public class SubSiteSignRateTask implements Runnable {
    /** Order ids belonging to this slice. */
    private final List<String> orderIdList;
    /** Collector that receives this slice's result. */
    private final SynchroniseUtil<SiteSignRateVO> synchroniseUtil;
    /** End index (exclusive) of this slice within the full order-id list. */
    private final Integer toIndex;
    /** Start index (inclusive) of this slice within the full order-id list. */
    private final Integer fromIndex;
    /** The aggregate VO; only its site id and site name are read here. */
    private final SiteSignRateVO originVO;
    /** Query parameters. */
    private final SiteSignRateParam param;

    public SubSiteSignRateTask(Integer fromIndex,
                               Integer toIndex,
                               SiteSignRateVO vo,
                               SiteSignRateParam param,
                               List<String> orderIdList,
                               SynchroniseUtil<SiteSignRateVO> synchroniseUtil) {
        this.fromIndex = fromIndex;
        this.toIndex = toIndex;
        this.orderIdList = orderIdList;
        this.synchroniseUtil = synchroniseUtil;
        this.param = param;
        this.originVO = vo;
    }

    @Override
    public void run() {
        log.debug("线程{}正在计算此批次记录,数组下标为[{},{}]", Thread.currentThread().getName(), fromIndex, toIndex);
        boolean reported = false;
        try {
            SiteSignRateVO vo = SiteSignRateVO.init(originVO.getSiteId(), originVO.getSiteName());
            // Business logic for this slice goes here.
            synchroniseUtil.addResult(vo);
            reported = true;
        } finally {
            if (!reported) {
                // Release this slice's latch slot without contributing a (partial) result.
                synchroniseUtil.addResult(java.util.Collections.<SiteSignRateVO>emptyList());
            }
        }
    }
}