7 #include <Metal/Metal.h>
11 #include "flutter/fml/backtrace.h"
12 #include "flutter/fml/closure.h"
13 #include "flutter/fml/logging.h"
14 #include "flutter/fml/trace_event.h"
// Constructor: takes (weak) ownership of the rendering context and retains the
// Metal command buffer from which compute encoders are created later.
// NOTE(review): the constructor body is not visible in this chunk (embedded
// numbering jumps from 31 to 40) — any validity setup cannot be confirmed here.
29 ComputePassMTL::ComputePassMTL(std::weak_ptr<const Context> context,
30 id<MTLCommandBuffer> buffer)
31 : ComputePass(
std::move(context)), buffer_(buffer) {
// Reports whether this pass can encode work.
// NOTE(review): the method body is elided from this view — presumably returns a
// validity flag set during construction; verify against the full source.
40 bool ComputePassMTL::IsValid()
const {
// Records a debug label for this pass; OnEncodeCommands() below reads label_
// and applies it to the compute encoder.
// NOTE(review): the method body is elided from this view.
44 void ComputePassMTL::OnSetLabel(
const std::string& label) {
// Entry point for encoding this compute pass into the command buffer: creates
// a compute command encoder, applies the optional debug label, and delegates
// the per-command work to the EncodeCommands() overload further below.
// NOTE(review): several statements (early-return bodies, closing braces) are
// elided from this view; control flow shown here is partial.
51 bool ComputePassMTL::OnEncodeCommands(
const Context& context,
52 const ISize& grid_size,
53 const ISize& thread_group_size)
const {
54 TRACE_EVENT0(
"impeller",
"ComputePassMTL::EncodeCommands");
// An empty grid or threadgroup size would dispatch no work; treat as misuse.
59 FML_DCHECK(!grid_size.IsEmpty() && !thread_group_size.IsEmpty());
62 auto compute_command_encoder = [buffer_ computeCommandEncoder];
// Encoder creation can fail (e.g. the command buffer is no longer usable).
64 if (!compute_command_encoder) {
// Forward the pass label (if any) to the encoder for GPU debugger captures.
68 if (!label_.empty()) {
69 [compute_command_encoder setLabel:@(label_.c_str())];
// Success or failure, the encoder must be ended before this scope exits: a
// Metal command buffer can only service one active encoder at a time.
74 fml::ScopedCleanupClosure auto_end(
75 [compute_command_encoder]() { [compute_command_encoder endEncoding]; });
77 return EncodeCommands(context.GetResourceAllocator(), compute_command_encoder,
78 grid_size, thread_group_size);
// ComputePassBindingsCache constructor tail: captures the encoder that all
// cached binding state is forwarded to. (Signature truncated in this view.)
92 : encoder_(encoder) {}
// SetComputePipelineState body fragment: skip the Metal call when the same
// pipeline object is already current, otherwise record it and apply it.
99 if (pipeline == pipeline_) {
102 pipeline_ = pipeline;
103 [encoder_ setComputePipelineState:pipeline_];
// Returns the compute pipeline state most recently applied through
// SetComputePipelineState(), or nil when none has been set yet.
id<MTLComputePipelineState> GetPipeline() const {
  return pipeline_;
}
// Binds |buffer| at |index|, eliding redundant Metal calls: if the same buffer
// is already bound at the slot, only the offset is updated (or nothing at all
// when the offset also matches).
// NOTE(review): the early-return statements and closing braces are elided from
// this view; the branch structure shown is partial.
108 void SetBuffer(uint64_t index, uint64_t offset, id<MTLBuffer> buffer) {
109 auto found = buffers_.find(index);
// Is the same buffer object already bound at this slot?
110 if (found != buffers_.end() && found->second.buffer == buffer) {
// Identical offset too — nothing to re-encode.
112 if (found->second.offset == offset) {
// Same buffer, new offset: the cheaper setBufferOffset call suffices.
118 found->second.offset = offset;
120 [encoder_ setBufferOffset:offset atIndex:index];
// Different (or first) buffer for this slot: record it and do the full bind.
124 buffers_[index] = {buffer,
static_cast<size_t>(offset)};
125 [encoder_ setBuffer:buffer offset:offset atIndex:index];
// SetTexture body fragment: bind |texture| at |index| unless it is already the
// bound texture for that slot. (Method signature truncated in this view.)
129 auto found = textures_.find(index);
130 if (found != textures_.end() && found->second == texture) {
134 textures_[index] = texture;
135 [encoder_ setTexture:texture atIndex:index];
// Binds |sampler| at |index|, skipping the call when it is already bound.
// NOTE(review): the early-return and closing braces are elided from this view.
139 void SetSampler(uint64_t index, id<MTLSamplerState> sampler) {
140 auto found = samplers_.find(index);
141 if (found != samplers_.end() && found->second == sampler) {
145 samplers_[index] = sampler;
146 [encoder_ setSamplerState:sampler atIndex:index];
// Cached binding state, keyed by slot index, used by the setters above to
// elide redundant encoder calls.
151 struct BufferOffsetPair {
152 id<MTLBuffer> buffer =
nullptr;
// NOTE(review): the struct's offset member and closing brace are elided here.
155 using BufferMap = std::map<uint64_t, BufferOffsetPair>;
156 using TextureMap = std::map<uint64_t, id<MTLTexture>>;
157 using SamplerMap = std::map<uint64_t, id<MTLSamplerState>>;
// The encoder all cached state is forwarded to; fixed for the cache lifetime.
159 const id<MTLComputeCommandEncoder> encoder_;
// nil until the first SetComputePipelineState() call.
160 id<MTLComputePipelineState> pipeline_ =
nullptr;
162 TextureMap textures_;
163 SamplerMap samplers_;
// Bind() fragment: resolves the buffer view to a concrete device buffer via
// the allocator; a null result aborts the bind. (Surrounding code truncated.)
174 auto device_buffer = view.
buffer->GetDeviceBuffer(allocator);
175 if (!device_buffer) {
// EncodeCommands fragment (signature head truncated in this view): iterates
// the recorded compute commands, applies pipeline/buffer/texture/sampler
// bindings through the de-duplicating cache, then dispatches threadgroups
// sized against the pipeline's hardware limit.
// NOTE(review): many statements (early returns, closing braces, the branch
// that selects between the 1-D and 2-D dispatch paths) are elided from this
// view; the control flow shown is partial.
203 id<MTLComputeCommandEncoder> encoder,
204 const ISize& grid_size,
205 const ISize& thread_group_size)
const {
// A degenerate grid has nothing to dispatch.
206 if (grid_size.width == 0 || grid_size.height == 0) {
// Cache that elides redundant encoder state-setting calls per command.
210 ComputePassBindingsCache pass_bindings(encoder);
212 fml::closure pop_debug_marker = [encoder]() { [encoder popDebugGroup]; };
213 for (
const ComputeCommand& command :
commands_) {
214 #ifdef IMPELLER_DEBUG
// Pair every pushDebugGroup with a popDebugGroup; Release() disarms the
// cleanup when no group was pushed for an unlabeled command.
215 fml::ScopedCleanupClosure auto_pop_debug_marker(pop_debug_marker);
216 if (!command.label.empty()) {
217 [encoder pushDebugGroup:@(command.label.c_str())];
219 auto_pop_debug_marker.Release();
223 pass_bindings.SetComputePipelineState(
// Bind every buffer the command references; a failed bind aborts encoding.
227 for (
const BufferAndUniformSlot& buffer : command.bindings.buffers) {
228 if (!
Bind(pass_bindings, *allocator, buffer.slot.ext_res_0,
229 buffer.view.resource)) {
// Bind every sampled image (texture + sampler pair).
234 for (
const TextureAndSampler& data : command.bindings.sampled_images) {
235 if (!
Bind(pass_bindings, data.slot.texture_index, *data.sampler,
236 *data.texture.resource)) {
244 auto width = grid_size.width;
245 auto height = grid_size.height;
// The per-pipeline hardware cap on threads in a single threadgroup.
247 auto maxTotalThreadsPerThreadgroup =
static_cast<int64_t
>(
248 pass_bindings.GetPipeline().maxTotalThreadsPerThreadgroup);
// 1-D dispatch path: enough full-width threadgroups to cover the grid width.
252 int64_t threadGroups = std::max(
253 static_cast<int64_t
>(
254 std::ceil(width * 1.0 / maxTotalThreadsPerThreadgroup * 1.0)),
256 [encoder dispatchThreadgroups:MTLSizeMake(threadGroups, 1, 1)
257 threadsPerThreadgroup:MTLSizeMake(maxTotalThreadsPerThreadgroup,
// 2-D dispatch path: halve each dimension until width*height fits under the
// hardware limit. NOTE(review): passing the same MTLSize as both the
// threadgroup count and the threadgroup size looks suspicious (it launches
// size^2 threads per axis) — verify against the full source before relying
// on this.
260 while (width * height > maxTotalThreadsPerThreadgroup) {
261 width = std::max(1LL, width / 2);
262 height = std::max(1LL, height / 2);
265 auto size = MTLSizeMake(width, height, 1);
266 [encoder dispatchThreadgroups:size threadsPerThreadgroup:size];