7 #include "flutter/fml/trace_event.h"
// Constructs a Vulkan compute pass.
//
// `context` is moved into the ComputePass base-class constructor and
// `command_buffer` is moved into the `command_buffer_` member. Both arrive as
// std::weak_ptr, so the arguments themselves do not carry ownership of the
// context or the command buffer.
// NOTE(review): the constructor body is not visible in this excerpt.
16 ComputePassVK::ComputePassVK(std::weak_ptr<const Context> context,
17 std::weak_ptr<CommandBufferVK> command_buffer)
18 : ComputePass(
std::move(context)),
19 command_buffer_(
std::move(command_buffer)) {
// Const query returning a bool.
// NOTE(review): the body is not visible in this excerpt; presumably it reports
// whether the pass was set up successfully — confirm against the full file.
25 bool ComputePassVK::IsValid()
const {
// Hook that receives the label string for this pass.
// NOTE(review): the body is not visible in this excerpt; presumably it stores
// the value in `label_`, which OnEncodeCommands reads when pushing a debug
// group — confirm against the full file.
29 void ComputePassVK::OnSetLabel(
const std::string& label) {
37 const vk::CommandBuffer& buffer) {
// Fill in the barrier description:
//   - source side: a write access performed in the transfer stage;
//   - destination side: a shader read performed in the compute-shader stage;
//   - the layout requested after the barrier is shader-read-only-optimal.
// NOTE(review): the declaration of `barrier`, the resource it is applied to,
// and the call that records it into `buffer` are not visible in this excerpt.
40 barrier.
src_access = vk::AccessFlagBits::eTransferWrite;
41 barrier.
src_stage = vk::PipelineStageFlagBits::eTransfer;
42 barrier.
dst_access = vk::AccessFlagBits::eShaderRead;
43 barrier.
dst_stage = vk::PipelineStageFlagBits::eComputeShader;
45 barrier.
new_layout = vk::ImageLayout::eShaderReadOnlyOptimal;
56 const vk::CommandBuffer& buffer) {
61 const vk::CommandBuffer& buffer) {
// Encodes this compute pass into the Vulkan command buffer obtained from the
// owning CommandBufferVK's encoder: binds the compute pipeline and a
// descriptor set, then issues a dispatch derived from `grid_size`.
//
// @param context            Rendering context (its use is not visible in this
//                           excerpt).
// @param grid_size          Requested grid extent; width/height seed the
//                           dispatch dimensions.
// @param thread_group_size  Only checked for non-emptiness in the visible
//                           lines.
// @return bool. NOTE(review): the return statements are not visible in this
//         excerpt; the validation logs suggest failure paths return false —
//         confirm against the full file.
70 bool ComputePassVK::OnEncodeCommands(
const Context& context,
71 const ISize& grid_size,
72 const ISize& thread_group_size)
const {
73 TRACE_EVENT0(
"impeller",
"ComputePassVK::EncodeCommands");
// Debug-only precondition: neither size may be empty.
78 FML_DCHECK(!grid_size.IsEmpty() && !thread_group_size.IsEmpty());
// The command buffer is held weakly; promote it for the duration of encoding
// and log a validation error if it has already been destroyed.
81 auto command_buffer = command_buffer_.lock();
82 if (!command_buffer) {
83 VALIDATION_LOG <<
"Command buffer died before commands could be encoded.";
86 auto encoder = command_buffer->GetEncoder();
// Scoped cleanup that pops the debug group on scope exit. The matching push
// below only happens when a label is set.
// NOTE(review): confirm that the cleanup is released when no group was pushed
// (that branch is not visible in this excerpt); otherwise the pop would be
// unbalanced.
91 fml::ScopedCleanupClosure pop_marker(
92 [&encoder]() { encoder->PopDebugGroup(); });
93 if (!label_.empty()) {
94 encoder->PushDebugGroup(label_.c_str());
98 auto cmd_buffer = encoder->GetCommandBuffer();
101 VALIDATION_LOG <<
"Could not update binding layouts for compute pass.";
// Descriptor sets come back wrapped in a status-like result; `.ok()` is
// checked before `.value()` is taken.
104 auto desc_sets_result =
106 if (!desc_sets_result.ok()) {
109 auto desc_sets = desc_sets_result.value();
111 TRACE_EVENT0(
"impeller",
"EncodeComputePassCommands");
112 size_t desc_index = 0;
// Bind the compute pipeline and the descriptor set selected by `desc_index`
// on the compute bind point.
116 cmd_buffer.bindPipeline(vk::PipelineBindPoint::eCompute,
117 pipeline_vk.GetPipeline());
118 cmd_buffer.bindDescriptorSets(
119 vk::PipelineBindPoint::eCompute,
120 pipeline_vk.GetPipelineLayout(),
122 {vk::DescriptorSet{desc_sets[desc_index]}},
// Query the physical-device limits used to size the dispatch.
// NOTE(review): the values passed to dispatch() are workgroup *counts*, which
// Vulkan bounds with maxComputeWorkGroupCount; maxComputeWorkGroupSize bounds
// the shader's local size. Confirm this is the intended limit.
128 auto device_properties = vk_context.GetPhysicalDevice().getProperties();
130 auto max_wg_size = device_properties.limits.maxComputeWorkGroupSize;
132 int64_t width = grid_size.width;
133 int64_t height = grid_size.height;
// 1D path: dispatch ceil(width / max_wg_size[0]) groups along X. The trailing
// `* 1.0` multiplies the quotient and has no effect on the result.
// NOTE(review): the condition selecting this path and the second argument of
// std::max are not visible in this excerpt.
138 int64_t threadGroups = std::max(
139 static_cast<int64_t
>(std::ceil(width * 1.0 / max_wg_size[0] * 1.0)),
141 cmd_buffer.dispatch(threadGroups, 1, 1);
// 2D path: repeatedly halve each dimension (never below 1) until it fits the
// per-axis limit, then dispatch width x height x 1 groups.
// NOTE(review): halving shrinks the dispatch, so a grid larger than the limit
// appears not to be fully covered — confirm whether this is a known
// limitation.
143 while (width > max_wg_size[0]) {
144 width = std::max(
static_cast<int64_t
>(1), width / 2);
146 while (height > max_wg_size[1]) {
147 height = std::max(
static_cast<int64_t
>(1), height / 2);
149 cmd_buffer.dispatch(width, height, 1);