#version 450

// Direct (non-im2col) 2D convolution: every invocation produces exactly one
// output value, selected by its global invocation id:
//   x -> flattened output position, split into (row, col) using out_w
//   y -> filter index (selects the weight row of length K and the bias entry)
//   z -> batch index
// NOTE(review): M is presumably out_h * out_w, N the number of output
// channels and K channels * filter_h * filter_w -- confirm against the host
// code that fills the push constants.

#define LOCAL_SZ_X 256

// Source image; indexed below as [batch][channel][row][col].
layout(binding = 0) readonly buffer Input0
{
    float image_data[];
};

// One bias value per filter; only read when has_bias == 1.
layout(binding = 1) readonly buffer Input1
{
    float bias_data[];
};

// Filter weights; indexed below as [filter][channel][filter_row][filter_col].
// The block is named Input3 although it is bound at slot 2; the name is kept
// unchanged so that reflection-based lookups are not affected.
layout(binding = 2) readonly buffer Input3
{
    float weight_data[];
};

// Result; indexed below as [batch][filter][flattened output position].
layout(binding = 3) writeonly buffer Output
{
    float convolved_image_data[];
};

layout(push_constant) uniform pushBlock
{
    int in_h;
    int in_w;
    int out_h;
    int out_w;
    int stride_h;
    int stride_w;
    int pad_h;
    int pad_w;
    int filter_h;
    int filter_w;
    int dilation_h;
    int dilation_w;
    int channels;
    int batch;
    int has_bias;
    int M;
    int K;
    int N;
} p;

layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;

void main()
{
    int out_pos    = int(gl_GlobalInvocationID.x);
    int filter_idx = int(gl_GlobalInvocationID.y);
    int batch_idx  = int(gl_GlobalInvocationID.z);

    // Invocations outside the M x N x batch output volume have nothing to do.
    if (out_pos >= p.M || filter_idx >= p.N || batch_idx >= p.batch)
    {
        return;
    }

    // Top-left input coordinate covered by the filter window. It can be
    // negative, or reach past the image, because of padding.
    int out_row = out_pos / p.out_w;
    int out_col = out_pos % p.out_w;
    int top     = out_row * p.stride_h - p.pad_h;
    int left    = out_col * p.stride_w - p.pad_w;

    int plane_size  = p.in_h * p.in_w;
    int batch_base  = batch_idx * p.in_h * p.in_w * p.channels;
    int filter_base = filter_idx * p.K;
    int taps_per_ch = p.filter_h * p.filter_w;

    float acc = 0.0;

    for (int ch = 0; ch < p.channels; ++ch)
    {
        int plane_base  = batch_base + ch * plane_size;
        int kplane_base = filter_base + ch * taps_per_ch;

        for (int ky = 0; ky < p.filter_h; ++ky)
        {
            int  src_row     = top + ky * p.dilation_h;
            bool row_inside  = (src_row >= 0) && (src_row < p.in_h);
            int  row_base    = plane_base + src_row * p.in_w;
            int  kernel_base = kplane_base + ky * p.filter_w;

            for (int kx = 0; kx < p.filter_w; ++kx)
            {
                int src_col = left + kx * p.dilation_w;

                // Taps that land in the padding region contribute nothing and
                // must not be read from the buffer.
                if (row_inside && (src_col >= 0) && (src_col < p.in_w))
                {
                    acc += image_data[row_base + src_col] * weight_data[kernel_base + kx];
                }
            }
        }
    }

    if (p.has_bias == 1)
    {
        acc += bias_data[filter_idx];
    }

    int out_index = batch_idx * p.M * p.N + out_pos + filter_idx * p.M;
    convolved_image_data[out_index] = acc;
}