7 #include "flutter/fml/trace_event.h"
// Constructor. Moves the (weak) rendering context into the ComputePass base
// class and retains a weak reference to the owning Vulkan command buffer —
// weak ownership means the pass never keeps either alive on its own.
// NOTE(review): this listing is elided; the constructor body (original lines
// 19-23) is not visible here, and the stray leading numbers are artifacts of
// the capture, reproduced verbatim.
15 ComputePassVK::ComputePassVK(std::weak_ptr<const Context> context,
16 std::weak_ptr<CommandBufferVK> command_buffer)
17 : ComputePass(
std::move(context)),
18 command_buffer_(
std::move(command_buffer)) {
// Returns whether this compute pass is usable for encoding. The body
// (original lines 25-27) is elided from this view — presumably it returns a
// validity flag set during construction; TODO confirm against the full file.
24 bool ComputePassVK::IsValid()
const {
// Records a debug label for this pass. The body (original lines 29 onward)
// is elided from this view — presumably it stores `label` into `label_`,
// which OnEncodeCommands later pushes as a debug group; TODO confirm.
28 void ComputePassVK::OnSetLabel(
const std::string& label) {
// Fragment of a barrier helper (the function name and parameter head are
// elided from this view). The visible assignments configure an image memory
// barrier that hands an image written by a transfer operation over to a
// compute shader for sampling:
//   - source:      transfer-write access at the transfer pipeline stage;
//   - destination: shader-read access at the compute-shader stage;
//   - new layout:  eShaderReadOnlyOptimal, the layout sampled reads expect.
// NOTE(review): the `barrier` object's type and the code that submits the
// barrier to `buffer` are elided — verify against the full file.
36 const vk::CommandBuffer& buffer) {
39 barrier.
src_access = vk::AccessFlagBits::eTransferWrite;
40 barrier.
src_stage = vk::PipelineStageFlagBits::eTransfer;
41 barrier.
dst_access = vk::AccessFlagBits::eShaderRead;
42 barrier.
dst_stage = vk::PipelineStageFlagBits::eComputeShader;
44 barrier.
new_layout = vk::ImageLayout::eShaderReadOnlyOptimal;
// Fragment: trailing parameter of another helper taking the Vulkan command
// buffer (original line 55). Its name, other parameters, and body are all
// elided from this view — cannot document further without the full file.
55 const vk::CommandBuffer& buffer) {
// Fragment of a helper (original lines 60-61) that iterates the recorded
// compute `commands` against the given command buffer. The loop body and the
// function's name/head are elided — presumably a per-command barrier or
// encoding pass over `commands`; TODO confirm against the full file.
60 const vk::CommandBuffer& buffer) {
61 for (
const auto& command : commands) {
// Fragment of the descriptor-set population helper (head elided; original
// line 73 shows its last parameter). Visible behavior: builds the
// vk::WriteDescriptorSet list for a compute command's sampled images and
// uniform buffers, applies it via updateDescriptorSets, then binds the
// resulting descriptor set to the compute bind point.
73 size_t command_count) {
// The maps below own the DescriptorBufferInfo/DescriptorImageInfo structs so
// that the pointers stored in each write_set (pImageInfo/pBufferInfo) stay
// valid until updateDescriptorSets runs. Keyed by binding index.
83 std::unordered_map<uint32_t, vk::DescriptorBufferInfo> buffers;
84 std::unordered_map<uint32_t, vk::DescriptorImageInfo> images;
85 std::vector<vk::WriteDescriptorSet> writes;
// Lambda: emit one combined-image-sampler write per sampled image binding.
86 auto bind_images = [&encoder,
90 ](
const Bindings& bindings) ->
bool {
91 for (
const auto& [index, data] : bindings.sampled_images) {
92 auto texture = data.texture.resource;
// Track the texture with the encoder so it outlives command-buffer
// execution; the failure branch here is elided from this view.
96 if (!encoder.
Track(texture) ||
103 vk::DescriptorImageInfo image_info;
// Image must already be in eShaderReadOnlyOptimal when the shader samples it
// (see the transfer->compute barrier helper earlier in this file).
104 image_info.imageLayout = vk::ImageLayout::eShaderReadOnlyOptimal;
106 image_info.imageView = texture_vk.GetImageView();
108 vk::WriteDescriptorSet write_set;
109 write_set.dstSet = vk_desc_set.value();
110 write_set.dstBinding = slot.
binding;
111 write_set.descriptorCount = 1u;
112 write_set.descriptorType = vk::DescriptorType::eCombinedImageSampler;
// Insert into the map and point the write at the stored (stable) copy.
113 write_set.pImageInfo = &(images[slot.
binding] = image_info);
115 writes.push_back(write_set);
// Lambda: emit one buffer write per uniform/storage buffer binding.
// (Descriptor type assignment is elided from this view.)
121 auto bind_buffers = [&allocator,
127 ](
const Bindings& bindings) ->
bool {
128 for (
const auto& [buffer_index, data] : bindings.buffers) {
129 const auto& buffer_view = data.view.resource.buffer;
// Resolve the host buffer view to its backing device buffer; bail with a
// validation message if allocation/resolution failed.
131 auto device_buffer = buffer_view->GetDeviceBuffer(allocator);
132 if (!device_buffer) {
133 VALIDATION_LOG <<
"Failed to get device buffer for vertex binding";
// Keep the device buffer alive for the lifetime of the command buffer.
142 if (!encoder.Track(device_buffer)) {
146 uint32_t offset = data.view.resource.range.offset;
148 vk::DescriptorBufferInfo buffer_info;
149 buffer_info.buffer = buffer;
150 buffer_info.offset = offset;
151 buffer_info.range = data.view.resource.range.length;
// Look up the shader-declared layout entry matching this binding slot so the
// write uses the pipeline's expected descriptor metadata.
154 auto layout_it = std::find_if(desc_set.begin(), desc_set.end(),
156 return layout.binding == uniform.binding;
158 if (layout_it == desc_set.end()) {
159 VALIDATION_LOG <<
"Failed to get descriptor set layout for binding "
163 auto layout = *layout_it;
165 vk::WriteDescriptorSet write_set;
166 write_set.dstSet = vk_desc_set.value();
167 write_set.dstBinding = uniform.
binding;
168 write_set.descriptorCount = 1u;
// Same stable-storage trick as the image path: store then point.
170 write_set.pBufferInfo = &(buffers[uniform.
binding] = buffer_info);
172 writes.push_back(write_set);
// Flush all accumulated writes in a single Vulkan call (no copies).
181 context.
GetDevice().updateDescriptorSets(writes, {});
// Bind the now-populated descriptor set for the compute pipeline.
183 encoder.GetCommandBuffer().bindDescriptorSets(
184 vk::PipelineBindPoint::eCompute,
187 {vk::DescriptorSet{*vk_desc_set}},
// Encodes all recorded compute commands into the Vulkan command buffer:
// binds each command's pipeline and descriptor sets, then issues dispatches
// sized from `grid_size` and the device's maxComputeWorkGroupSize limits.
// Returns false on validation failures (dead command buffer, bad layouts).
// NOTE(review): interior lines are elided in this view; comments below only
// describe the statements that are visible.
193 bool ComputePassVK::OnEncodeCommands(
const Context& context,
194 const ISize& grid_size,
195 const ISize& thread_group_size)
const {
196 TRACE_EVENT0(
"impeller",
"ComputePassVK::EncodeCommands");
// Dispatching a zero-sized grid or workgroup is a caller bug.
201 FML_DCHECK(!grid_size.IsEmpty() && !thread_group_size.IsEmpty());
// The pass only weakly references its command buffer; it may already be gone.
204 auto command_buffer = command_buffer_.lock();
205 if (!command_buffer) {
206 VALIDATION_LOG <<
"Command buffer died before commands could be encoded.";
209 auto encoder = command_buffer->GetEncoder();
// Scope a debug group around the whole pass; Release() below disarms the
// cleanup when no label was pushed, so pops stay balanced with pushes.
214 fml::ScopedCleanupClosure pop_marker(
215 [&encoder]() { encoder->PopDebugGroup(); });
216 if (!label_.empty()) {
217 encoder->PushDebugGroup(label_.c_str());
219 pop_marker.Release();
221 auto cmd_buffer = encoder->GetCommandBuffer();
// Failure branch of a (elided) binding-layout update call.
224 VALIDATION_LOG <<
"Could not update binding layouts for compute pass.";
229 TRACE_EVENT0(
"impeller",
"EncodeComputePassCommands");
// Skip commands with no pipeline (presumably `continue`; the elided line
// between these two statements should confirm).
232 if (!command.pipeline) {
238 cmd_buffer.bindPipeline(vk::PipelineBindPoint::eCompute,
239 pipeline_vk.GetPipeline());
// Query device compute limits to clamp dispatch dimensions.
251 auto device_properties = vk_context.GetPhysicalDevice().getProperties();
253 auto max_wg_size = device_properties.limits.maxComputeWorkGroupSize;
255 int64_t width = grid_size.width;
256 int64_t height = grid_size.height;
// 1-D path (visible here): number of groups = ceil(width / max X group size),
// dispatched along X only.
261 int64_t threadGroups = std::max(
262 static_cast<int64_t
>(std::ceil(width * 1.0 / max_wg_size[0] * 1.0)),
264 cmd_buffer.dispatch(threadGroups, 1, 1);
// 2-D path: halve each dimension (never below 1) until it fits the device's
// per-axis workgroup limit, then dispatch width x height x 1 groups.
266 while (width > max_wg_size[0]) {
267 width = std::max(
static_cast<int64_t
>(1), width / 2);
269 while (height > max_wg_size[1]) {
270 height = std::max(
static_cast<int64_t
>(1), height / 2);
272 cmd_buffer.dispatch(width, height, 1);