Tutorial 15. Kernel Example - Graph-Cut

The graph-cut algorithm solves the max-flow min-cut problem, which appears in many global optimization problems such as background segmentation, image stitching and network analysis. One such graph-cut algorithm is Push and Relabel, which was designed by Andrew V. Goldberg and Robert Tarjan. This algorithm uses a breadth-first approach, which makes it possible to run it efficiently on a GPU, although with some challenges. The main challenge is how to resolve inter-node dependencies while still exposing parallel operations.

There are multiple nodes in a system. Each node has a set of variables, such as its excess flow and its capacities in the north, south, east and west directions. The high-level idea of the Push and Relabel algorithm is to call Relabel and Push in turn until no active node can be pushed any further. For background segmentation, the final segmentation result is in the height matrix.

Below is Relabel pseudo code.

if active(x) do
    my_height = HEIGHT_MAX;
    for each y = neighbor(x)
        if capacity(x,y) > 0 do
            my_height = min(my_height, height(y)+1);
        done
    end
    height(x) = my_height;// update height
done

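In the Relabel kernel, the north, south and east neighbor references are processed in parallel (SIMD), while the west neighbor reference is processed one node at a time (SIMD1), so that each node uses the already-updated height of the node to its west; this helps the algorithm converge.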

Below is Push pseudo code.

if active(x) do
    foreach y = neighbor(x)
        if height(y) == height(x) - 1 do
            flow = min( capacity(x,y), excess_flow(x));
            excess_flow(x) -= flow;
            excess_flow(y) += flow;
            capacity(x,y) -= flow;
            capacity(y,x) += flow;
        done
    end
done

Push is parallelized by splitting it into separate row and column push kernels, which preserve the horizontal and vertical thread dependencies during the push operations.

The following is representative CM kernel code.

GC_Relabel_u32

// GC_Relabel_u32 - Relabel step for one 8x8 block of nodes. The block is
// selected by the thread origin (h_pos, v_pos).
//
// For every active node x (excess flow > 0) the height becomes
//   min(HEIGHT_MAX, height(y) + 1) over the neighbours y with capacity(x,y) > 0.
// The north, south and east neighbours are evaluated 8 lanes (one row) at a
// time. The west neighbour is evaluated one pixel at a time, left to right,
// so that each pixel sees the height just stored for the pixel on its left.
//
// Parameters:
//   pBlockMaskIndex   one uchar per block, read at (h_pos, v_pos); a zero
//                     entry makes the thread return without touching heights.
//   pExcessFlowIndex  short excess flow per node (only read here).
//   pHeightIndex      uint height per node; read with a 1-pixel border around
//                     the block, and the 8x8 block is written back.
//   pWestCapIndex, pNorthCapIndex, pEastCapIndex, pSouthCapIndex
//                     short capacity per node for each direction (only read).
//   HEIGHT_MAX        starting value of the minimum for every relabelled node.
extern "C" _GENX_MAIN_ void
GC_Relabel_u32(SurfaceIndex pBlockMaskIndex, SurfaceIndex pExcessFlowIndex,
               SurfaceIndex pHeightIndex, SurfaceIndex pWestCapIndex,
               SurfaceIndex pNorthCapIndex, SurfaceIndex pEastCapIndex,
               SurfaceIndex pSouthCapIndex, uint HEIGHT_MAX) {
  matrix<uchar, 1, 1> BlockMask;
  matrix<uint, 10, 8> Height10x8; // 10 GRFs: row above + 8x8 block + row below
  matrix<uint, 8, 1> LBorder;     // height column just left of the block
  matrix<uint, 8, 1> RBorder;     // height column just right of the block

  matrix<short, 8, 8> ExcessFlow; // 4 GRFs
  matrix<short, 8, 8> WestCap;    // 4 GRFs
  matrix<short, 8, 8> NorthCap;   // 4 GRFs
  matrix<short, 8, 8> EastCap;    // 4 GRFs
  matrix<short, 8, 8> SouthCap;   // 4 GRFs

  matrix<short, 8, 8> mask;  // 4 GRFs, 1 where the node is active
  vector<uint, 8> NewHeight; // 1 GRF
  vector<uint, 8> Neighbor;  // 1 GRF
  vector<uint, 2> temps;     // only element 0 is used (west candidate height)

  uint h_pos = get_thread_origin_x();
  uint v_pos = get_thread_origin_y();

  // Skip inactive block
  read(pBlockMaskIndex, h_pos, v_pos, BlockMask);
  if (BlockMask[0][0] == 0)
    return;

  short baseX = 1; // block base
  short baseY = 1; // Offset for extra row

  // Height block origin without border
  int dX0 = 8 * h_pos * sizeof(uint); // in bytes
  int dY0 = 8 * v_pos;                // in rows

  // Height block origin with border
  int dX = (8 * h_pos - baseX) * sizeof(uint); // in bytes
  int dY = 8 * v_pos - baseY;                  // in rows

  // ExcessFlow, WestCap, NorthCap, EastCap, SouthCap origin without border
  int sX = 8 * h_pos * 2; // in bytes
  int sY = 8 * v_pos;     // in rows

  // Active-node mask: computed once; ExcessFlow is not modified by this kernel.
  read(pExcessFlowIndex, sX, sY, ExcessFlow);
  mask.merge(1, 0, ExcessFlow > 0);
  // Early exit for a block without active nodes is left disabled:
  //    if (!mask.any())
  //        return;

  // Read Height: rows dY .. dY+9 (one row above and one row below the block)
  read(pHeightIndex, dX0, dY, Height10x8.select<8, 1, 8, 1>(0, 0));
  read(pHeightIndex, dX0, dY + 8, Height10x8.select<2, 1, 8, 1>(8, 0));
  // Read left and right Height borders (columns 8*h_pos-1 and 8*h_pos+8)
  read(pHeightIndex, dX, dY0, LBorder);
  read(pHeightIndex, dX + 9 * sizeof(uint), dY0, RBorder);

  read(pWestCapIndex, sX, sY, WestCap);
  read(pNorthCapIndex, sX, sY, NorthCap);
  read(pEastCapIndex, sX, sY, EastCap);
  read(pSouthCapIndex, sX, sY, SouthCap);

  // Rows are processed top to bottom; heights are updated in place, so the
  // north row read for row j already contains the values stored for row j-1.
#pragma unroll
  for (int j = 0; j < 8; j++) {
    if (mask.row(j).any()) {
      NewHeight = HEIGHT_MAX;

      // North neighbour: x, y-1
      Neighbor = Height10x8.row(baseY - 1 + j) + 1;
      NewHeight.merge(Neighbor, mask.row(j) & (NorthCap.row(j) > 0) &
                                    (Neighbor < NewHeight));

      // South neighbour: x, y+1
      Neighbor = Height10x8.row(baseY + 1 + j) + 1;
      NewHeight.merge(Neighbor, mask.row(j) & (SouthCap.row(j) > 0) &
                                    (Neighbor < NewHeight));

      // East neighbour: x+1, y
      // Lanes 0..6 take columns 1..7 of the row (the two 4-wide copies overlap
      // at lane 3 with the same value); lane 7 takes the right border.
      Neighbor.select<4, 1>(0) = Height10x8.select<1, 1, 4, 1>(baseY + j, 1);
      Neighbor.select<4, 1>(3) = Height10x8.select<1, 1, 4, 1>(baseY + j, 4);
      Neighbor[7] = RBorder[j][0];
      Neighbor += 1;
      NewHeight.merge(Neighbor, mask.row(j) & (EastCap.row(j) > 0) &
                                    (Neighbor < NewHeight));

      // West pix[0]: its west neighbour is the left border column
      if (mask[j][0]) {
        // West neighbour
        temps[0] = LBorder[j][0] + 1;
        if ((WestCap[j][0] > 0) && (temps[0] < NewHeight[0]))
          NewHeight[0] = temps[0];
        // Update Height so the next pixel to the right will use this new value
        Height10x8[baseY + j][0] = NewHeight[0];
      }

      // West pix[1:7]: strictly sequential, each pixel reads the (possibly
      // just updated) height of the pixel on its left.
#pragma unroll
      for (int i = 1; i < 8; i++) {
        if (mask[j][i]) {
          temps[0] = Height10x8[baseY + j][-1 + i] + 1;
          // West neighbour
          if ((WestCap[j][i] > 0) && (temps[0] < NewHeight[i]))
            NewHeight[i] = temps[0];
          Height10x8[baseY + j][i] = NewHeight[i];
        }
      }
    }
  }

  // Output the updated height block 8x8 (rows 1..8 of Height10x8; the border
  // rows and columns are not written back)
  write(pHeightIndex, 8 * h_pos * sizeof(uint), 8 * v_pos,
        Height10x8.select<8, 1, 8, 1>(1, 0));
  //    cm_fence();
}

GC_V_Push_VWF_u32

// GC_V_Push_VWF_u32 - vertical (north/south) Push step for one block of
// 8 rows x 16 columns of nodes.
//
// The thread loads the block plus one row above and one row below (10x16),
// then, row by row from top to bottom, pushes excess flow from every active
// node (excess flow > 0 and height < HEIGHT_MAX) to its north and then its
// south neighbour when that neighbour's height is exactly one lower.
// If anything was pushed, the 10x16 ExcessFlow/NorthCap/SouthCap tiles are
// written back. Finally, if the block still contains active nodes, an atomic
// add is issued on pStatusIndex.
//
// Parameters:
//   pExcessFlowIndex      short excess flow per node (read and written).
//   pHeightIndex          uint height per node (read only).
//   pNorthCapIndex,
//   pSouthCapIndex        short capacities (read and written).
//   pStatusIndex          surface updated with ATOMIC_ADD at offset 0 when
//                         active nodes remain in the block.
//   HEIGHT_MAX            nodes at this height are not pushed.
//   PhysicalThreadsWidth,
//   BankHeight            used to remap the thread origin to a block
//                         coordinate (see "Actual coordinate" below).
extern "C" _GENX_MAIN_ void
GC_V_Push_VWF_u32(SurfaceIndex pExcessFlowIndex, SurfaceIndex pHeightIndex,
                  SurfaceIndex pNorthCapIndex, SurfaceIndex pSouthCapIndex,
                  SurfaceIndex pStatusIndex, uint HEIGHT_MAX,
                  int PhysicalThreadsWidth, int BankHeight) {
  // Output block size = 8x16
  matrix<uint, 10, 16>
      Height; // Input. Block size = 8x16, plus extra 2 rows 10x16
  matrix<short, 10, 16> ExcessFlow; // Input and output. 10x16 (block + 2 rows)
  matrix<short, 10, 16> NorthCap;   // Input and output. 10x16 (block + 2 rows)
  matrix<short, 10, 16> SouthCap;   // Input and output. 10x16 (block + 2 rows)

  vector<short, 16> mask;    // 1 GRF
  vector<short, 16> mask2;   // 1 GRF
  vector<short, 16> flow;    // 1 GRF
  matrix<short, 8, 16> temp; // 8 GRFs
  matrix<short, 8, 16> mask8x16;

  vector<uint, 8> element_offset(0); // 1 GRF, all zero: every lane -> offset 0
  vector<uint, 8> count(0);          // 1 GRF, last argument of the atomic write
  vector<uint, 8> tmp(0);            // 1 GRF, source operand of the atomic add

  // Virtual coordinate
  uint h_pos0 = get_thread_origin_x();
  uint v_pos0 = get_thread_origin_y();

  // Actual coordinate
  // x = x' % width
  // y = x' / width * height_b + y'
  // x and y are the coordinates used to address the block; x' and y' are the
  // thread-origin coordinates. width is PhysicalThreadsWidth and height_b is
  // the bank height (BankHeight).
  uint h_pos = h_pos0 % PhysicalThreadsWidth;
  uint v_pos = h_pos0 / PhysicalThreadsWidth * BankHeight +
               v_pos0; // BankHeight is the bank height in blocks

  short baseX = 0; // block base
  short baseY = 1; // Offset for extra row

  uint dX = (16 * h_pos - baseX) *
            sizeof(uint);      // in bytes (baseX is 0: no horizontal border)
  uint dY = 8 * v_pos - baseY; // in rows

  uint nX = (16 * h_pos - baseX) *
            sizeof(short);     // in bytes (baseX is 0: no horizontal border)
  uint nY = 8 * v_pos - baseY; // in rows

  short update = 0; // set to 1 inside the push branches below

  // NOTE(review): the tutorial text says the push kernels preserve thread
  // dependencies; presumably cm_wait() blocks until the threads this one
  // depends on have finished - confirm against the CM specification.
  cm_wait();

  // Read 10 rows x 16 columns (as two 8x8 and two 2x8 tiles)
  read(pHeightIndex, dX, dY, Height.select<8, 1, 8, 1>(0, 0));
  read(pHeightIndex, dX + 8 * sizeof(uint), dY,
       Height.select<8, 1, 8, 1>(0, 8));
  read(pHeightIndex, dX, dY + 8, Height.select<2, 1, 8, 1>(8, 0));
  read(pHeightIndex, dX + 8 * sizeof(uint), dY + 8,
       Height.select<2, 1, 8, 1>(8, 8));

  // Nothing to do when every node of the 8x16 block is at HEIGHT_MAX.
  // NOTE(review): this return (and the next one) happens before the
  // cm_fence() at the end of the kernel.
  mask8x16.merge(1, 0, Height.select<8, 1, 16, 1>(1, 0) == HEIGHT_MAX);
  if (mask8x16.all())
    return;

  read(pExcessFlowIndex, nX, nY, ExcessFlow.select<8, 1, 16, 1>(0, 0));
  read(pExcessFlowIndex, nX, nY + 8, ExcessFlow.select<2, 1, 16, 1>(8, 0));

  // mask8x16 is 1 for nodes WITHOUT excess flow; leave if that is all of them
  mask8x16.merge(1, 0, ExcessFlow.select<8, 1, 16, 1>(1, 0) <= 0);
  if (mask8x16.all())
    return;

  read(pNorthCapIndex, nX, nY, NorthCap.select<8, 1, 16, 1>(0, 0));
  read(pNorthCapIndex, nX, nY + 8, NorthCap.select<2, 1, 16, 1>(8, 0));

  read(pSouthCapIndex, nX, nY, SouthCap.select<8, 1, 16, 1>(0, 0));
  read(pSouthCapIndex, nX, nY + 8, SouthCap.select<2, 1, 16, 1>(8, 0));

  // Rows are handled top to bottom; a push into row j+1 is visible when
  // row j+1 is processed in the next iteration.
#pragma unroll
  for (int j = 0; j < 8; j++) {
    // Active lanes of this row: excess flow > 0 and Height < HEIGHT_MAX
    mask.merge(1, 0,
               ExcessFlow.row(baseY + j) > 0 &
                   Height.row(baseY + j) < HEIGHT_MAX);

    SIMD_IF_BEGIN(mask) {
      // North neighbour (x, y-1): push when it is exactly one level lower
      mask2.merge(1, 0,
                  (Height.row(baseY + j - 1) == Height.row(baseY + j) - 1));
      SIMD_IF_BEGIN(mask2) {
        // flow = min(capacity towards north, excess flow)
        flow =
            cm_min<short>(NorthCap.row(baseY + j), ExcessFlow.row(baseY + j));
        SIMD_IF_BEGIN(flow != 0) {
          ExcessFlow.row(baseY + j) -= flow;
          ExcessFlow.row(baseY + j - 1) += flow;
          NorthCap.row(baseY + j) -= flow;
          // reverse edge: the north node's south capacity grows
          SouthCap.row(baseY + j - 1) += flow;
          update = 1;
        }
        SIMD_IF_END;
      }
      SIMD_IF_END;

      // South neighbour (x, y+1): uses the excess flow left after the north
      // push above
      mask2.merge(1, 0,
                  (Height.row(baseY + j + 1) == Height.row(baseY + j) - 1));
      SIMD_IF_BEGIN(mask2) {
        // flow = min(capacity towards south, excess flow)
        flow =
            cm_min<short>(SouthCap.row(baseY + j), ExcessFlow.row(baseY + j));
        SIMD_IF_BEGIN(flow != 0) {
          ExcessFlow.row(baseY + j) -= flow;
          ExcessFlow.row(baseY + j + 1) += flow;
          SouthCap.row(baseY + j) -= flow;
          // reverse edge: the south node's north capacity grows
          NorthCap.row(baseY + j + 1) += flow;
          update = 1;
        }
        SIMD_IF_END;
      }
      SIMD_IF_END;
    }
    SIMD_IF_END;
  }

  if (update) {
    // Output ExcessFlow and capacity tiles [10 rows x 16 cols], i.e. the
    // block rows plus the row above and the row below
    write(pExcessFlowIndex, nX, nY, ExcessFlow.select<8, 1, 16, 1>(0, 0));
    write(pExcessFlowIndex, nX, nY + 8, ExcessFlow.select<2, 1, 16, 1>(8, 0));

    write(pNorthCapIndex, nX, nY, NorthCap.select<8, 1, 16, 1>(0, 0));
    write(pNorthCapIndex, nX, nY + 8, NorthCap.select<2, 1, 16, 1>(8, 0));

    write(pSouthCapIndex, nX, nY, SouthCap.select<8, 1, 16, 1>(0, 0));
    write(pSouthCapIndex, nX, nY + 8, SouthCap.select<2, 1, 16, 1>(8, 0));
  }

  // temp is 1 for every node of the 8x16 block that is still active
  // (excess flow > 0 and height < HEIGHT_MAX) after the pushes
  temp.merge(1, 0,
             (ExcessFlow.select<8, 1, 16, 1>(baseY, baseX) > 0) &
                 (Height.select<8, 1, 16, 1>(baseY, baseX) < HEIGHT_MAX));

  // Number of still-active nodes in the block
  tmp[0] = cm_sum<short>(temp);

  // Update status to indicate that the block still has at least one active node
  //    if (temp.any()) {
  if (tmp[0]) {
    // write(pStatusIndex, ATOMIC_INC, 0, element_offset, tmp, count);
    // NOTE(review): element_offset is all zero, so tmp[0] (the active count)
    // and tmp[1] (1) both target offset 0 - presumably the host only checks
    // the status for non-zero; confirm.
    tmp[1] = 1;
    write(pStatusIndex, ATOMIC_ADD, 0, element_offset, tmp, count);
  }

  cm_fence();
}

GC_H_Push_NR_VWF_u32

// GC_H_Push_NR_VWF_u32 - horizontal (west/east) Push step for one block of
// 8 rows x 8 columns of nodes.
//
// The thread loads tiles that start 2 pixels to the left of the block,
// transposes them so that each image column becomes a row, and then runs the
// same row-wise push pattern as the vertical kernel: for each of the 8 block
// columns, left to right, active nodes (excess flow > 0 and height <
// HEIGHT_MAX) push to the west and then to the east neighbour when that
// neighbour's height is exactly one lower. The tiles are transposed back and,
// if anything was pushed, written out.
//
// Parameters:
//   pExcessFlowIndex       short excess flow per node (read and written).
//   pHeightIndex           uint height per node (read only).
//   pWestCapIndex,
//   pEastCapIndex          short capacities (read and written).
//   HEIGHT_MAX             nodes at this height are not pushed.
//   PhysicalThreadsHeight,
//   BankWidth              used to remap the thread origin to a block
//                          coordinate (see "Actual coordinate" below).
//
// The Transpose_* helpers are not shown in this section; the comments below
// assume they are plain matrix transposes.
extern "C" _GENX_MAIN_ void
GC_H_Push_NR_VWF_u32(SurfaceIndex pExcessFlowIndex, SurfaceIndex pHeightIndex,
                     SurfaceIndex pWestCapIndex, SurfaceIndex pEastCapIndex,
                     uint HEIGHT_MAX, int PhysicalThreadsHeight,
                     int BankWidth) {
// Block size is 8x8
#define ROWS 8

  matrix<uint, ROWS, 16>
      Height; // Input. Declared 8x16; only columns 0..11 are loaded below
  matrix<short, ROWS, 16> ExcessFlow; // Input and output, full 8x16. 8GRFs
  matrix<short, ROWS, 16> WestCap;    // Input and output, full 8x16. 8GRFs
  matrix<short, ROWS, 16> EastCap;    // Input and output, full 8x16. 8GRFs

  // Transposed matrix
  matrix<short, 16, ROWS> ExcessFlow_t; // 8 GRFs
  matrix<short, 16, ROWS> WestCap_t;    // 8 GRFs
  matrix<short, 16, ROWS> EastCap_t;    // 8 GRFs
  matrix<uint, 16, ROWS> Height_t;      // 16 GRFs

  matrix<short, ROWS, 8> mask8x8;
  vector<short, 8> mask;
  vector<short, 8> mask2;
  vector<short, 8> flow;

  // Virtual coordinate
  uint h_pos0 = get_thread_origin_x();
  uint v_pos0 = get_thread_origin_y();

  // Actual coordinate
  // y = y' % Height
  // x = y' / Height * Width_b + x'
  // x and y are the coordinates used to address the block; x' and y' are the
  // thread-origin coordinates. Height is PhysicalThreadsHeight and Width_b is
  // the bank width (BankWidth).
  uint v_pos = v_pos0 % PhysicalThreadsHeight;
  uint h_pos = v_pos0 / PhysicalThreadsHeight * BankWidth +
               h_pos0; // BankWidth is the bank width in blocks

  short baseX = 2; // block base: the tiles start 2 pixels left of the block
  short baseY = 0;

  uint dX = (8 * h_pos - baseX) *
            sizeof(uint); // in bytes, 2 extra pixels for DWORD aligned write
  uint dY = ROWS * v_pos - baseY; // in rows

  uint nX = (8 * h_pos - baseX) *
            sizeof(short); // in bytes, 2 extra pixels for DWORD aligned write
  uint nY = ROWS * v_pos - baseY; // in rows

  short update = 0; // set to 1 inside the push branches below

  // NOTE(review): the tutorial text says the push kernels preserve thread
  // dependencies; presumably cm_wait() blocks until the threads this one
  // depends on have finished - confirm against the CM specification.
  cm_wait();

  // Reads the whole 8x16 ExcessFlow tile
  read(pExcessFlowIndex, nX, nY, ExcessFlow);

  // Leave when no node of the 8x8 block has excess flow.
  // NOTE(review): this return (and the next one) happens before the
  // cm_fence() at the end of the kernel.
  mask8x8.merge(1, 0, ExcessFlow.select<8, 1, 8, 1>(baseY, baseX) > 0);
  if (!mask8x8.any())
    return;

  Transpose_8x16_To_16x8_Short(ExcessFlow, ExcessFlow_t);

  // Read 8 rows x 12 columns of Height (8 + 4); columns 12..15 of Height are
  // never loaded
  read(pHeightIndex, dX, dY, Height.select<ROWS, 1, 8, 1>(0, 0));
  read(pHeightIndex, dX + 8 * sizeof(uint), dY,
       Height.select<ROWS, 1, 4, 1>(0, 8));

  // Leave when every node of the 8x8 block is at or above HEIGHT_MAX
  mask8x8.merge(1, 0, Height.select<ROWS, 1, 8, 1>(baseY, baseX) < HEIGHT_MAX);
  if (!mask8x8.any())
    return;

  Transpose_8x16_To_16x8_Uint(Height, Height_t);

  read(pWestCapIndex, nX, nY, WestCap);
  read(pEastCapIndex, nX, nY, EastCap);

  Transpose_8x16_To_16x8_Short(WestCap, WestCap_t);
  Transpose_8x16_To_16x8_Short(EastCap, EastCap_t);

  // In the transposed tiles the roles of x and y are swapped: the 2-pixel
  // left border becomes a 2-row top border.
  baseX = 0; // Transposed block base
  baseY = 2;

  // j walks the 8 block columns left to right. The loop touches transposed
  // rows baseY+j-1 .. baseY+j+1, i.e. rows 1..10, so the rows that come from
  // the unloaded Height columns 12..15 are never used.
#pragma unroll
  for (int j = 0; j < 8; j++) {

    // Active lanes of this column: excess flow > 0 and Height < HEIGHT_MAX
    mask.merge(1, 0,
               ExcessFlow_t.row(baseY + j) > 0 &
                   Height_t.row(baseY + j) < HEIGHT_MAX);

    SIMD_IF_BEGIN(mask) {
      // West neighbour: previous row of the transposed tile, i.e. (x-1, y)
      mask2.merge(1, 0,
                  (Height_t.row(baseY + j - 1) == Height_t.row(baseY + j) - 1));
      SIMD_IF_BEGIN(mask2) {
        // flow = min(capacity towards west, excess flow)
        flow = cm_min<short>(WestCap_t.row(baseY + j),
                             ExcessFlow_t.row(baseY + j));
        SIMD_IF_BEGIN(flow != 0) {
          ExcessFlow_t.row(baseY + j) -= flow;
          ExcessFlow_t.row(baseY + j - 1) += flow;
          WestCap_t.row(baseY + j) -= flow;
          // reverse edge: the west node's east capacity grows
          EastCap_t.row(baseY + j - 1) += flow;
          update = 1;
        }
        SIMD_IF_END;
      }
      SIMD_IF_END;

      // East neighbour: next row of the transposed tile, i.e. (x+1, y); uses
      // the excess flow left after the west push above
      mask2.merge(1, 0,
                  (Height_t.row(baseY + j + 1) == Height_t.row(baseY + j) - 1));
      SIMD_IF_BEGIN(mask2) {
        // flow = min(capacity towards east, excess flow)
        flow = cm_min<short>(EastCap_t.row(baseY + j),
                             ExcessFlow_t.row(baseY + j));
        SIMD_IF_BEGIN(flow != 0) {
          ExcessFlow_t.row(baseY + j) -= flow;
          ExcessFlow_t.row(baseY + j + 1) += flow;
          EastCap_t.row(baseY + j) -= flow;
          // reverse edge: the east node's west capacity grows
          WestCap_t.row(baseY + j + 1) += flow;
          update = 1;
        }
        SIMD_IF_END;
      }
      SIMD_IF_END;
    }
    SIMD_IF_END;
  }

  // Reverse transpose (Height is not modified, so it is not transposed back)
  Transpose_16x8_To_8x16_Short(ExcessFlow_t, ExcessFlow);
  Transpose_16x8_To_8x16_Short(WestCap_t, WestCap);
  Transpose_16x8_To_8x16_Short(EastCap_t, EastCap);

  //    baseX = 2;   // Restore block base
  //    baseY = 0;

  if (update) {
    // Output ExcessFlow and capacity tiles. The whole 8x16 matrices are
    // passed to write(); the loop above can only have changed columns 1..10
    // (the 8 block columns plus one neighbour column on each side).
    write(pExcessFlowIndex, nX, nY, ExcessFlow);
    write(pWestCapIndex, nX, nY, WestCap);
    write(pEastCapIndex, nX, nY, EastCap);
  }

  cm_fence();
}

The source code also includes further variations of these kernels for higher performance. They are beyond the scope of this tutorial.