#include "impeller/renderer/backend/vulkan/compute_pass_vk.h"

#include "fml/status.h"
#include "impeller/renderer/backend/vulkan/command_buffer_vk.h"
#include "impeller/renderer/backend/vulkan/compute_pipeline_vk.h"
#include "impeller/renderer/backend/vulkan/context_vk.h"
#include "impeller/renderer/backend/vulkan/device_buffer_vk.h"
#include "impeller/renderer/backend/vulkan/formats_vk.h"
#include "impeller/renderer/backend/vulkan/sampler_vk.h"
#include "impeller/renderer/backend/vulkan/texture_vk.h"
#include "vulkan/vulkan_structs.hpp"

namespace impeller {
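// ComputePassVK translates Impeller compute commands into Vulkan dispatches.
// BindResource() stages descriptor writes into fixed-size workspace arrays;
// Compute() patches and flushes them in a single updateDescriptorSets() call
// before recording the dispatch.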
ComputePassVK::ComputePassVK(std::shared_ptr<const Context> context,
                             std::shared_ptr<CommandBufferVK> command_buffer)
    : ComputePass(std::move(context)),
      command_buffer_(std::move(command_buffer)) {
  // Cache the device limit once so Compute() can clamp dispatch dimensions
  // without re-querying the physical device.
  max_wg_size_ = ContextVK::Cast(*context_)
                     .GetPhysicalDevice()
                     .getProperties()
                     .limits.maxComputeWorkGroupSize;
  is_valid_ = true;
}
ComputePassVK::~ComputePassVK() = default;
bool ComputePassVK::IsValid() const {
  return is_valid_;
}
void ComputePassVK::OnSetLabel(const std::string& label) {
  if (label.empty()) {
    return;
  }
  label_ = label;
}
// |ComputePass|
void ComputePassVK::SetCommandLabel(std::string_view label) {
#ifdef IMPELLER_DEBUG
  command_buffer_->PushDebugGroup(label);
  has_label_ = true;
#endif  // IMPELLER_DEBUG
}
// |ComputePass|
void ComputePassVK::SetPipeline(
    const std::shared_ptr<Pipeline<ComputePipelineDescriptor>>& pipeline) {
  const auto& pipeline_vk = ComputePipelineVK::Cast(*pipeline);
  const vk::CommandBuffer& command_buffer_vk =
      command_buffer_->GetCommandBuffer();
  command_buffer_vk.bindPipeline(vk::PipelineBindPoint::eCompute,
                                 pipeline_vk.GetPipeline());
  pipeline_layout_ = pipeline_vk.GetPipelineLayout();

  auto descriptor_result = command_buffer_->AllocateDescriptorSets(
      pipeline_vk.GetDescriptorSetLayout(), ContextVK::Cast(*context_));
  if (!descriptor_result.ok()) {
    return;
  }
  descriptor_set_ = descriptor_result.value();
  pipeline_valid_ = true;
}
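// Note that `pipeline_valid_` stays false when descriptor set allocation
// fails above, so a subsequent Compute() cancels instead of dispatching with
// an incomplete binding state.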
// |ComputePass|
fml::Status ComputePassVK::Compute(const ISize& grid_size) {
  if (grid_size.IsEmpty() || !pipeline_valid_) {
    bound_image_offset_ = 0u;
    bound_buffer_offset_ = 0u;
    descriptor_write_offset_ = 0u;
    has_label_ = false;
    pipeline_valid_ = false;
    return fml::Status(fml::StatusCode::kCancelled,
                       "Invalid pipeline or empty grid.");
  }
  const ContextVK& context_vk = ContextVK::Cast(*context_);
  // The destination set is only known now; patch the staged writes before
  // flushing them to the device.
  for (auto i = 0u; i < descriptor_write_offset_; i++) {
    write_workspace_[i].dstSet = descriptor_set_;
  }

  context_vk.GetDevice().updateDescriptorSets(descriptor_write_offset_,
                                              write_workspace_.data(), 0u, {});
  const vk::CommandBuffer& command_buffer_vk =
      command_buffer_->GetCommandBuffer();

  command_buffer_vk.bindDescriptorSets(
      vk::PipelineBindPoint::eCompute,  // pipeline bind point
      pipeline_layout_,                 // layout
      0,                                // first set
      1,                                // set count
      &descriptor_set_,                 // sets
      0,                                // dynamic offset count
      nullptr                           // dynamic offsets
  );
  int64_t width = grid_size.width;
  int64_t height = grid_size.height;

  // Special case for linear (1D) workloads.
  if (height == 1) {
    command_buffer_vk.dispatch(width, 1, 1);
  } else {
    // Clamp oversized dimensions by repeated halving until they fit within
    // the device's maximum compute workgroup size.
    while (width > max_wg_size_[0]) {
      width = std::max(static_cast<int64_t>(1), width / 2);
    }
    while (height > max_wg_size_[1]) {
      height = std::max(static_cast<int64_t>(1), height / 2);
    }
    command_buffer_vk.dispatch(width, height, 1);
  }
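  // Worked example of the clamping above: a 4096x512 grid on a device whose
  // maxComputeWorkGroupSize is 1024x1024x64 halves width 4096 -> 2048 ->
  // 1024, leaves height untouched, and records dispatch(1024, 512, 1).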
#ifdef IMPELLER_DEBUG
  if (has_label_) {
    command_buffer_->PopDebugGroup();
  }
  has_label_ = false;
#endif  // IMPELLER_DEBUG

  bound_image_offset_ = 0u;
  bound_buffer_offset_ = 0u;
  descriptor_write_offset_ = 0u;
  pipeline_valid_ = false;

  return fml::Status();
}
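// Every dispatch consumes the staged bindings: the workspace offsets and
// `pipeline_valid_` are reset above, so callers must run a fresh
// SetPipeline()/BindResource() sequence before the next Compute().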
// |ResourceBinder|
bool ComputePassVK::BindResource(ShaderStage stage,
                                 DescriptorType type,
                                 const ShaderUniformSlot& slot,
                                 const ShaderMetadata* metadata,
                                 BufferView view) {
  return BindResource(slot.binding, type, view);
}
// |ResourceBinder|
bool ComputePassVK::BindResource(ShaderStage stage,
                                 DescriptorType type,
                                 const SampledImageSlot& slot,
                                 const ShaderMetadata* metadata,
                                 std::shared_ptr<const Texture> texture,
                                 raw_ptr<const Sampler> sampler) {
  if (bound_image_offset_ >= kMaxBindings) {
    return false;
  }
  if (!texture->IsValid() || !sampler) {
    return false;
  }
  const TextureVK& texture_vk = TextureVK::Cast(*texture);
  const SamplerVK& sampler_vk = SamplerVK::Cast(*sampler);

  if (!command_buffer_->Track(texture)) {
    return false;
  }

  vk::DescriptorImageInfo image_info;
  image_info.imageLayout = vk::ImageLayout::eShaderReadOnlyOptimal;
  image_info.sampler = sampler_vk.GetSampler();
  image_info.imageView = texture_vk.GetImageView();
  image_workspace_[bound_image_offset_++] = image_info;

  vk::WriteDescriptorSet write_set;
  write_set.dstBinding = slot.binding;
  write_set.descriptorCount = 1u;
  write_set.descriptorType = ToVKDescriptorType(type);
  write_set.pImageInfo = &image_workspace_[bound_image_offset_ - 1];

  write_workspace_[descriptor_write_offset_++] = write_set;
  return true;
}
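// The eShaderReadOnlyOptimal layout above is only what the descriptor
// declares; the actual image layout transition is expected to be handled by
// the texture's own barrier tracking. This function just records the
// descriptor write and keeps the texture alive via Track().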
bool ComputePassVK::BindResource(size_t binding,
                                 DescriptorType type,
                                 BufferView view) {
  if (bound_buffer_offset_ >= kMaxBindings) {
    return false;
  }

  auto buffer = DeviceBufferVK::Cast(*view.GetBuffer()).GetBuffer();
  if (!buffer) {
    return false;
  }

  std::shared_ptr<const DeviceBuffer> device_buffer = view.TakeBuffer();
  if (device_buffer && !command_buffer_->Track(device_buffer)) {
    return false;
  }

  uint32_t offset = view.GetRange().offset;

  vk::DescriptorBufferInfo buffer_info;
  buffer_info.buffer = buffer;
  buffer_info.offset = offset;
  buffer_info.range = view.GetRange().length;
  buffer_workspace_[bound_buffer_offset_++] = buffer_info;

  vk::WriteDescriptorSet write_set;
  write_set.dstBinding = binding;
  write_set.descriptorCount = 1u;
  write_set.descriptorType = ToVKDescriptorType(type);
  write_set.pBufferInfo = &buffer_workspace_[bound_buffer_offset_ - 1];

  write_workspace_[descriptor_write_offset_++] = write_set;
  return true;
}
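// A caller typically stages bindings immediately before a dispatch. Sketch
// (the pass, buffer view, and binding index are assumed to come from
// elsewhere):
//
//   pass->BindResource(/*binding=*/0u, DescriptorType::kStorageBuffer, view);
//   pass->Compute(ISize(1024, 1));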
// |ComputePass|
void ComputePassVK::AddBufferMemoryBarrier() {
  vk::MemoryBarrier barrier;
  barrier.srcAccessMask = vk::AccessFlagBits::eShaderWrite;
  barrier.dstAccessMask = vk::AccessFlagBits::eShaderRead;

  command_buffer_->GetCommandBuffer().pipelineBarrier(
      vk::PipelineStageFlagBits::eComputeShader,
      vk::PipelineStageFlagBits::eComputeShader, {}, 1, &barrier, 0, {}, 0, {});
}
// |ComputePass|
void ComputePassVK::AddTextureMemoryBarrier() {
  vk::MemoryBarrier barrier;
  barrier.srcAccessMask = vk::AccessFlagBits::eShaderWrite;
  barrier.dstAccessMask = vk::AccessFlagBits::eShaderRead;

  command_buffer_->GetCommandBuffer().pipelineBarrier(
      vk::PipelineStageFlagBits::eComputeShader,
      vk::PipelineStageFlagBits::eComputeShader, {}, 1, &barrier, 0, {}, 0, {});
}
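// Both helpers above issue a global vk::MemoryBarrier rather than
// per-buffer or per-image barriers; Khronos synchronization guidance
// suggests finer-grained compute barriers are frequently no cheaper in
// practice, though this is worth confirming per device.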
// |ComputePass|
bool ComputePassVK::EncodeCommands() const {
  // Pessimistically assume that a subsequent render pass will consume the
  // compute output as vertex or index data, and flush shader writes to
  // those stages.
  vk::MemoryBarrier barrier;
  barrier.srcAccessMask = vk::AccessFlagBits::eShaderWrite;
  barrier.dstAccessMask =
      vk::AccessFlagBits::eIndexRead | vk::AccessFlagBits::eVertexAttributeRead;

  command_buffer_->GetCommandBuffer().pipelineBarrier(
      vk::PipelineStageFlagBits::eComputeShader,
      vk::PipelineStageFlagBits::eVertexInput, {}, 1, &barrier, 0, {}, 0, {});

  return true;
}
// `ToVKDescriptorType(DescriptorType type)` and the `kMaxBindings` workspace
// capacity used above are declared in the corresponding headers, not in this
// file.

}  // namespace impeller
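// A minimal end-to-end usage sketch, assuming `context`, `compute_pipeline`,
// and `buffer_view` were created elsewhere (illustrative only; these names
// are not part of this file):
//
//   auto cmd_buffer = context->CreateCommandBuffer();
//   auto pass = cmd_buffer->CreateComputePass();
//   pass->SetPipeline(compute_pipeline);
//   pass->BindResource(0u, DescriptorType::kStorageBuffer, buffer_view);
//   pass->Compute(ISize(512, 512));
//   pass->EncodeCommands();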