7 #include <Metal/Metal.h>
11 #include "flutter/fml/backtrace.h"
12 #include "flutter/fml/closure.h"
13 #include "flutter/fml/logging.h"
14 #include "flutter/fml/trace_event.h"
// Constructs a Metal-backed compute pass: forwards the (weak) rendering
// context to the base ComputePass and retains the MTLCommandBuffer that the
// compute command encoder will later be created from (see OnEncodeCommands).
// NOTE(review): this chunk carries leaked line-number prefixes and is missing
// interior lines — the constructor body/closing brace is not visible here.
28 ComputePassMTL::ComputePassMTL(std::weak_ptr<const Context> context,
29 id<MTLCommandBuffer> buffer)
30 : ComputePass(
std::move(context)), buffer_(buffer) {
// Reports whether this pass can encode commands.
// NOTE(review): the body is not visible in this chunk — presumably returns a
// validity flag established at construction; confirm against the full file.
39 bool ComputePassMTL::IsValid()
const {
// Records a debug label for the pass; it is applied to the compute command
// encoder in OnEncodeCommands via [compute_command_encoder setLabel:].
// NOTE(review): the body is not visible in this chunk — presumably assigns
// label_ from the argument; confirm against the full file.
43 void ComputePassMTL::OnSetLabel(
const std::string& label) {
// Creates a compute command encoder on the pass's command buffer, applies the
// optional debug label, and delegates the actual per-command encoding to the
// private EncodeCommands overload. Encoding is always ended (endEncoding) via
// the scoped cleanup closure, on success and failure alike.
// NOTE(review): interior lines are missing from this chunk (e.g. the early
// return after the nil-encoder check and several closing braces).
50 bool ComputePassMTL::OnEncodeCommands(
const Context& context,
51 const ISize& grid_size,
52 const ISize& thread_group_size)
const {
53 TRACE_EVENT0(
"impeller",
"ComputePassMTL::EncodeCommands");
// Encoding with an empty grid or threadgroup size is a programmer error.
58 FML_DCHECK(!grid_size.IsEmpty() && !thread_group_size.IsEmpty());
61 auto compute_command_encoder = [buffer_ computeCommandEncoder];
// Messaging nil would silently no-op every encode call below, so bail out
// explicitly if the encoder could not be created.
63 if (!compute_command_encoder) {
// Apply the label recorded by OnSetLabel so the pass shows up named in
// Metal frame captures / Instruments.
67 if (!label_.empty()) {
68 [compute_command_encoder setLabel:@(label_.c_str())];
// Guarantee endEncoding runs on every exit path from this function.
73 fml::ScopedCleanupClosure auto_end(
74 [compute_command_encoder]() { [compute_command_encoder endEncoding]; });
76 return EncodeCommands(context.GetResourceAllocator(), compute_command_encoder,
77 grid_size, thread_group_size);
91 : encoder_(encoder) {}
// Body fragment of SetComputePipelineState: skips the Metal call when the
// pipeline is already bound, otherwise records and binds it.
// NOTE(review): the early-return inside the if and the surrounding braces are
// on lines missing from this chunk.
98 if (pipeline == pipeline_) {
101 pipeline_ = pipeline;
102 [encoder_ setComputePipelineState:pipeline_];
// Returns the most recently bound pipeline state (nullptr before the first
// SetComputePipelineState call). Used by EncodeCommands to query
// maxTotalThreadsPerThreadgroup when sizing dispatches.
105 id<MTLComputePipelineState>
GetPipeline()
const {
return pipeline_; }
// Binds a buffer at the given index, eliding redundant Metal calls:
//   - same buffer and same offset already bound  -> no-op;
//   - same buffer, different offset              -> cheaper setBufferOffset:;
//   - different (or no) buffer at this index     -> full setBuffer:.
// NOTE(review): the early-return statements and closing braces of the two if
// branches are on lines missing from this chunk.
107 void SetBuffer(uint64_t index, uint64_t offset, id<MTLBuffer> buffer) {
108 auto found = buffers_.find(index);
109 if (found != buffers_.end() && found->second.buffer == buffer) {
111 if (found->second.offset == offset) {
// Same buffer, new offset: update the cache and only rebind the offset.
117 found->second.offset = offset;
119 [encoder_ setBufferOffset:offset atIndex:index];
// New buffer for this index: record it and perform the full bind.
123 buffers_[index] = {buffer,
static_cast<size_t>(offset)};
124 [encoder_ setBuffer:buffer offset:offset atIndex:index];
// Body fragment of SetTexture (its signature line is missing from this
// chunk): skips the Metal call when the identical texture is already bound at
// this index, otherwise records and binds it.
128 auto found = textures_.find(index);
129 if (found != textures_.end() && found->second == texture) {
133 textures_[index] = texture;
134 [encoder_ setTexture:texture atIndex:index];
// Binds a sampler state at the given index, skipping the Metal call when the
// identical sampler is already bound there.
// NOTE(review): the early-return and closing braces are on lines missing from
// this chunk.
138 void SetSampler(uint64_t index, id<MTLSamplerState> sampler) {
139 auto found = samplers_.find(index);
140 if (found != samplers_.end() && found->second == sampler) {
144 samplers_[index] = sampler;
145 [encoder_ setSamplerState:sampler atIndex:index];
// Cached (buffer, offset) pair for a single buffer bind slot; SetBuffer reads
// and writes both members. NOTE(review): the offset member's declaration is on
// a line missing from this chunk (SetBuffer stores a size_t offset into it).
150 struct BufferOffsetPair {
151 id<MTLBuffer> buffer =
nullptr;
// Per-slot caches, keyed by the Metal argument-table bind index.
154 using BufferMap = std::map<uint64_t, BufferOffsetPair>;
155 using TextureMap = std::map<uint64_t, id<MTLTexture>>;
156 using SamplerMap = std::map<uint64_t, id<MTLSamplerState>>;
// The encoder all Set* calls forward to; const — the cache is per-encoder.
158 const id<MTLComputeCommandEncoder> encoder_;
// Last pipeline bound via SetComputePipelineState (nullptr until then).
159 id<MTLComputePipelineState> pipeline_ =
nullptr;
// NOTE(review): the buffers_ (BufferMap) member used by SetBuffer is declared
// on a line missing from this chunk.
161 TextureMap textures_;
162 SamplerMap samplers_;
// Fragment of a Bind helper (its signature is not visible in this chunk):
// resolves a buffer view to a concrete device buffer via the allocator and
// presumably fails the bind when allocation/resolution returns null — confirm
// against the full file.
173 auto device_buffer = view.
buffer->GetDeviceBuffer(allocator);
174 if (!device_buffer) {
// Fragment of the private EncodeCommands overload (the function's name/first
// parameter line precedes this chunk). For each recorded command it: pushes a
// debug group (debug builds only), binds the pipeline, binds all buffer and
// sampled-image resources through the ComputePassBindingsCache, then sizes
// and issues the dispatch. The use of `command` below implies an enclosing
// per-command loop whose header is on missing lines — confirm.
202 id<MTLComputeCommandEncoder> encoder,
203 const ISize& grid_size,
204 const ISize& thread_group_size)
const {
// Nothing to dispatch for a degenerate grid.
205 if (grid_size.width == 0 || grid_size.height == 0) {
// Per-encoder cache that elides redundant pipeline/buffer/texture/sampler
// rebinds across commands.
209 ComputePassBindingsCache pass_bindings(encoder);
211 fml::closure pop_debug_marker = [encoder]() { [encoder popDebugGroup]; };
213 #ifdef IMPELLER_DEBUG
// Pop the debug group on scope exit; Release() below cancels the pop when
// no group was pushed (unlabeled command).
214 fml::ScopedCleanupClosure auto_pop_debug_marker(pop_debug_marker);
215 if (!command.label.empty()) {
216 [encoder pushDebugGroup:@(command.label.c_str())];
218 auto_pop_debug_marker.Release();
222 pass_bindings.SetComputePipelineState(
// Bind every buffer resource recorded for this command; a failed bind
// aborts encoding (early return on lines missing from this chunk).
226 for (
const auto& buffer : command.bindings.buffers) {
227 if (!
Bind(pass_bindings, *allocator, buffer.first,
228 buffer.second.view.resource)) {
// Bind sampler + texture pairs for each sampled image.
233 for (
const auto& data : command.bindings.sampled_images) {
234 if (!
Bind(pass_bindings, data.first, *data.second.sampler.resource,
235 *data.second.texture.resource)) {
243 auto width = grid_size.width;
244 auto height = grid_size.height;
// Hardware limit on threads per threadgroup for the bound pipeline.
246 auto maxTotalThreadsPerThreadgroup =
static_cast<int64_t
>(
247 pass_bindings.GetPipeline().maxTotalThreadsPerThreadgroup);
// 1-D sizing: enough threadgroups of the maximum width to cover the grid.
// NOTE(review): the second std::max argument and the branch structure
// separating this 1-D dispatch from the 2-D path below are on missing
// lines — presumably an if/else on grid dimensionality; confirm.
251 int64_t threadGroups = std::max(
252 static_cast<int64_t
>(
253 std::ceil(width * 1.0 / maxTotalThreadsPerThreadgroup * 1.0)),
255 [encoder dispatchThreadgroups:MTLSizeMake(threadGroups, 1, 1)
256 threadsPerThreadgroup:MTLSizeMake(maxTotalThreadsPerThreadgroup,
// 2-D sizing: halve each dimension (never below 1) until the threadgroup
// fits within the pipeline's per-threadgroup thread limit.
259 while (width * height > maxTotalThreadsPerThreadgroup) {
260 width = std::max(1LL, width / 2);
261 height = std::max(1LL, height / 2);
264 auto size = MTLSizeMake(width, height, 1);
265 [encoder dispatchThreadgroups:size threadsPerThreadgroup:size];