SYCL* Unified Shared Memory Code Walkthrough

ID 657786
Updated 3/22/2022
Version Latest
Public

This walkthrough introduces Unified Shared Memory (USM) as an alternative to buffers for managing and accessing memory from the host and device. The program uses SYCL* and parallel computing patterns to determine whether each point in a region of the two-dimensional complex plane belongs to the Mandelbrot set, and it uses this Mandelbrot sample to explore USM.

Download the Mandelbrot source from GitHub.

Summary of Unified Shared Memory

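  • USM is a pointer-based memory-management feature of SYCL (part of the SYCL 2020 specification).
  • USM requires hardware support for a unified virtual address space, so a pointer value refers to the same allocation on the host and the device.
  • Allocations are requested through host-side calls, and USM offers three allocation types:
    • Host (malloc_host()): located in host memory, accessible by the host or device.
    • Device (malloc_device()): located in device memory, accessible only by the device; the host copies data in and out explicitly (for example, with queue::memcpy()).
    • Shared (malloc_shared()): can migrate between host and device memory as needed (migration is handled by the runtime, not the compiler), accessible by the host or device.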

Include Headers

Along with some other common libraries, the Mandelbrot code sample uses Sean's Toolbox* (STB) for data visualization. The STB libraries read and write image files.

#include <complex>
#include <exception>
#include <iomanip>
#include <iostream>

// stb/*.h files can be found in the dev-utilities include folder.
// Example: $ONEAPI_ROOT/dev-utilities/<version>/include/stb/*.h

#define STB_IMAGE_IMPLEMENTATION
#include "stb/stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb/stb_image_write.h"

The Mandelbrot code sample also utilizes functionality provided by dpc_common:

// dpc_common.hpp can be found in the dev-utilities include folder.
// Example: $ONEAPI_ROOT/dev-utilities/<version>/include/dpc_common.hpp

#include "dpc_common.hpp"

main.cpp Driver Functions

The driver file, main.cpp, contains the infrastructure to execute and evaluate the computation of the Mandelbrot set.

Queue Creation

The queue is created in main with the default selector, which picks a GPU if one is available and otherwise falls back to the CPU or host device. The queue is also given dpc_common::exception_handler, which handles asynchronous exceptions raised during kernel execution.

// Create a queue on the default device. Set SYCL_DEVICE_TYPE environment
// variable to (CPU|GPU|FPGA|HOST) to change the device.

queue q(default_selector{}, dpc_common::exception_handler);

Show Device

The ShowDevice() function prints the platform name and version, the device name, the maximum work-group size, and the number of compute units for the device the queue selected.

void ShowDevice(queue &q) {

  // Output platform and device information.
  
  auto device = q.get_device();
  auto p_name = device.get_platform().get_info<info::platform::name>();
  cout << std::setw(20) << "Platform Name: " << p_name << "\n";
  auto p_version = device.get_platform().get_info<info::platform::version>();
  cout << std::setw(20) << "Platform Version: " << p_version << "\n";
  auto d_name = device.get_info<info::device::name>();
  cout << std::setw(20) << "Device Name: " << d_name << "\n";
  auto max_work_group = device.get_info<info::device::max_work_group_size>();
  cout << std::setw(20) << "Max Work Group: " << max_work_group << "\n";
  auto max_compute_units = device.get_info<info::device::max_compute_units>();
  cout << std::setw(20) << "Max Compute Units: " << max_compute_units << "\n\n";
}

Execute

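The Execute() function constructs the parallel Mandelbrot object (MandelParallelUsm when MANDELBROT_USM is defined, otherwise the buffer-based MandelParallel) and a serial MandelSerial object. It runs the parallel version once to trigger JIT compilation, times the average of repeated parallel runs, prints and writes the results, times a serial run, and verifies the parallel output against the serial output.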

void Execute(queue &q) {

  // Demonstrate the Mandelbrot calculation serial and parallel.

#ifdef MANDELBROT_USM
  cout << "Parallel Mandelbrot set using USM.\n";
  MandelParallelUsm m_par(row_size, col_size, max_iterations, &q);
#else
  cout << "Parallel Mandelbrot set using buffers.\n";
  MandelParallel m_par(row_size, col_size, max_iterations);
#endif

  MandelSerial m_ser(row_size, col_size, max_iterations);

  // Run the code once to trigger JIT.

  m_par.Evaluate(q);

  // Run the parallel version and time it.

  dpc_common::TimeInterval t_par;
  for (int i = 0; i < repetitions; ++i) m_par.Evaluate(q);
  double parallel_time = t_par.Elapsed();

  // Print the results.

  m_par.Print();
  m_par.WriteImage();

  // Run the serial version.

  dpc_common::TimeInterval t_ser;
  m_ser.Evaluate();
  double serial_time = t_ser.Elapsed();

  // Report the results.

  cout << std::setw(20) << "Serial time: " << serial_time << "s\n";
  cout << std::setw(20) << "Parallel time: " << (parallel_time / repetitions)
       << "s\n";

  // Validate.

  m_par.Verify(m_ser);
}
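Execute() follows a common benchmarking pattern: one untimed run absorbs the one-time JIT compilation cost, and the reported parallel time is the average over `repetitions` timed runs. The sketch below shows the same pattern in standard C++. It assumes std::chrono::steady_clock as a stand-in for dpc_common::TimeInterval, and `AverageSeconds` is a hypothetical helper, not part of the sample.

```cpp
#include <cassert>
#include <chrono>

// Hypothetical helper: run `work` once untimed (warm-up, e.g. to absorb JIT
// compilation), then return the average wall-clock time in seconds over
// `repetitions` timed runs. std::chrono::steady_clock stands in for
// dpc_common::TimeInterval.
template <typename Work>
double AverageSeconds(Work &&work, int repetitions) {
  assert(repetitions > 0);
  work();  // Warm-up run, not timed.

  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < repetitions; ++i) work();
  std::chrono::duration<double> elapsed =
      std::chrono::steady_clock::now() - start;

  return elapsed.count() / repetitions;
}
```

For example, `AverageSeconds([&] { m_par.Evaluate(q); }, repetitions)` would correspond to the parallel timing in Execute(). Because each Evaluate() call waits on its kernel event before returning, the timed interval covers kernel completion rather than only kernel submission.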

Mandelbrot USM Usage

Mandel Parameters Class

The MandelParameters struct contains the parameters and helper functions needed to calculate the Mandelbrot set.

ComplexF Datatype

MandelParameters defines the type alias ComplexF, which represents a single-precision complex number.

typedef std::complex<float> ComplexF;

Point Function

The Point() function takes a complex point, c, and determines whether it belongs to the Mandelbrot set. Starting from $z_0 = 0$, it repeatedly applies $z_{n+1} = z_n^2 + c$ and counts the iterations for which z stays bounded ($|z_n|^2 < 4$, that is, $|z_n| < 2$), up to an arbitrary limit of max_iterations. It returns that count; a point whose count reaches max_iterations is treated as belonging to the set.

int Point(const ComplexF &c) const {
  int count = 0;
  ComplexF z = 0;

  for (int i = 0; i < max_iterations_; ++i) {
    auto r = z.real();
    auto im = z.imag();

    // Leave loop if diverging (|z|^2 >= 4, that is, |z| >= 2).

    if (((r * r) + (im * im)) >= 4.0f) {
      break;
    }

    // z = z * z + c;
    // complex_square(), defined elsewhere in the sample, computes z * z.

    z = complex_square(z) + c;
    count++;
  }

  return count;
}
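To experiment with Point() outside the class, the sketch below reproduces its loop as a free function in standard C++. It assumes that complex_square(z) computes z * z, as the comment in Point() indicates, and uses std::norm for the squared magnitude; `EscapeCount` is a hypothetical name.

```cpp
#include <cassert>
#include <complex>

using ComplexF = std::complex<float>;

// Standalone version of MandelParameters::Point() for experimentation.
// Iterates z_{n+1} = z_n^2 + c from z_0 = 0 and returns the number of
// iterations completed before |z|^2 >= 4, capped at max_iterations.
// Plain z * z replaces the sample's complex_square() helper.
int EscapeCount(const ComplexF &c, int max_iterations) {
  assert(max_iterations >= 0);
  int count = 0;
  ComplexF z = 0.0f;

  for (int i = 0; i < max_iterations; ++i) {
    // std::norm returns the squared magnitude, real^2 + imag^2.
    if (std::norm(z) >= 4.0f) break;  // Diverging: leave the loop.
    z = z * z + c;
    ++count;
  }
  return count;
}
```

For c = 1 the orbit is 0, 1, 2, and |2|^2 = 4 ends the loop, so the count is 2. For c = -1 the orbit alternates between -1 and 0, never escapes, and the count equals max_iterations.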

Scale Row and Column

The scale functions convert row and column indices into coordinates in the complex plane: ScaleRow() maps a row index to the real part, and ScaleCol() maps a column index to the imaginary part. Evaluate() uses them to turn each array index into its corresponding complex point.

// Scale from 0..row_count to -1.5..0.5 

float ScaleRow(int i) const { return -1.5f + (i * (2.0f / row_count_)); }

// Scale from 0..col_count to -1..1

float ScaleCol(int i) const { return -1.0f + (i * (2.0f / col_count_)); }
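The sketch below copies the two scale functions into free functions that take the grid size as a parameter, which makes the index-to-coordinate mapping easy to check in isolation. The parameterized signatures are an assumption for illustration; the class versions read row_count_ and col_count_ instead.

```cpp
#include <cassert>

// Standalone copies of ScaleRow()/ScaleCol(), taking the grid size as a
// parameter instead of reading the row_count_/col_count_ members.
// Row indices 0..row_count-1 map into the real interval [-1.5, 0.5);
// column indices 0..col_count-1 map into the imaginary interval [-1, 1).
float ScaleRow(int i, int row_count) {
  assert(row_count > 0);
  return -1.5f + (i * (2.0f / row_count));
}

float ScaleCol(int j, int col_count) {
  assert(col_count > 0);
  return -1.0f + (j * (2.0f / col_count));
}
```

With row_count = 4, rows 0 to 3 land on -1.5, -1.0, -0.5, and 0.0. The index row_count itself would map to 0.5, which is the exclusive upper end of the range named in the comment.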

Mandel Class

The Mandel class is the base class from which MandelParallelUsm is derived. It contains member functions for outputting the data visualization, which are covered in the Other Functions section.

Member Variables

  • MandelParameters p_: the MandelParameters object that holds the grid size and maximum iteration count.
  • int *data_: a pointer to the memory for storing the output data.

Mandel Parallel USM Class

This class is derived from the Mandel class and handles all the device code for offloading the Mandelbrot calculation using USM.

Constructor Device Initialization

The MandelParallelUsm constructor first calls the Mandel constructor, which initializes the parameters from its arguments. It then stores the queue pointer in the member variable q, which Alloc() and Free() use to allocate and release USM. Finally, it calls the Alloc() virtual member function.

MandelParallelUsm(int row_count, int col_count, int max_iterations, queue *q)
    : Mandel(row_count, col_count, max_iterations) {
  this->q = q;
  Alloc();
}

Alloc USM Initialization

The Alloc() virtual member function is overridden in the MandelParallelUsm class to use USM. It calls malloc_shared<int>() to allocate space for row_count * col_count integers and stores the returned pointer in data_. Because the allocation is shared, the same pointer can be dereferenced on both the host and the device.

virtual void Alloc() {
  MandelParameters p = GetParameters();
  data_ = malloc_shared<int>(p.row_count() * p.col_count(), *q);
}

Launch the Kernel with Evaluate 

The Evaluate() member function launches the kernel code and calculates the Mandelbrot set.

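Inside parallel_for(), each work item's linear ID (index) is mapped to a row and column, which ScaleRow() and ScaleCol() convert into a point in the complex plane. The MandelParameters Point() function then computes the iteration count for that point, and the result is written to the corresponding element of the shared allocation. The member pointer data_ is first copied into the local variable ldata so that the [=] lambda captures the pointer itself by value instead of capturing this, which points to a host object that the kernel cannot safely dereference.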

void Evaluate(queue &q) {

  // Iterate over image and check if each point is in Mandelbrot set.

  MandelParameters p = GetParameters();

  const int rows = p.row_count();
  const int cols = p.col_count();
  auto ldata = data_;

  // Iterate over image and compute mandel for each point.

  auto e = q.parallel_for(range(rows * cols), [=](id<1> index) {
    int i = index / cols;
    int j = index % cols;
    auto c = MandelParameters::ComplexF(p.ScaleRow(i), p.ScaleCol(j));
    ldata[index] = p.Point(c);
  });

  // Wait for the asynchronous computation on device to complete.

  e.wait();
}
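Each work item in Evaluate() receives a single linear index, and the division and remainder by cols recover a row-major (row, column) pair, so index = i * cols + j. The serial sketch below performs the same mapping with a plain loop, which can help when checking results on the host. `EvaluateSerial` and its `compute` callback are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Serial sketch of the index mapping used in Evaluate(): a 1-D range of
// rows * cols work items covers a row-major rows x cols grid. `compute`
// is a hypothetical per-point function standing in for
// p.Point(ComplexF(p.ScaleRow(i), p.ScaleCol(j))).
template <typename PointFn>
std::vector<int> EvaluateSerial(int rows, int cols, PointFn compute) {
  assert(rows > 0 && cols > 0);
  std::vector<int> data(static_cast<std::size_t>(rows) * cols);

  for (std::size_t index = 0; index < data.size(); ++index) {
    int i = static_cast<int>(index / cols);  // Row: which block of cols.
    int j = static_cast<int>(index % cols);  // Column within that row.
    data[index] = compute(i, j);
  }
  return data;
}
```

With rows = 3 and cols = 4, index 7 maps to row 1 and column 3, because 7 = 1 * 4 + 3.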

Free Shared Memory with Destructor

The destructor calls the Free() member function, which releases the shared allocation with free(data_, *q), passing the same queue that was used to allocate it, so the program does not leak memory.

virtual void Free() { free(data_, *q); }
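Calling Alloc() from the constructor and Free() from the destructor is the RAII idiom: the object owns the USM block for exactly its lifetime. The sketch below isolates that pattern in standard C++ so it runs without a SYCL runtime. `OwnedInts` is a hypothetical class, and its allocate/release callables are assumptions standing in for malloc_shared<int>(n, *q) and free(ptr, *q); any pair of matching functions, such as new[] and delete[], can be plugged in.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <new>
#include <utility>

// Hypothetical RAII owner mirroring the Alloc()/Free() pairing: the
// constructor acquires the block and the destructor releases it exactly once.
// The callables stand in for malloc_shared<int>(n, *q) and free(ptr, *q).
class OwnedInts {
 public:
  using AllocFn = std::function<int *(std::size_t)>;
  using FreeFn = std::function<void(int *)>;

  OwnedInts(std::size_t n, const AllocFn &alloc, FreeFn release)
      : data_(alloc(n)), release_(std::move(release)) {
    // USM allocation functions report failure by returning nullptr.
    if (data_ == nullptr) throw std::bad_alloc();
  }
  ~OwnedInts() { release_(data_); }

  // Copying is disabled so two owners never free the same pointer.
  OwnedInts(const OwnedInts &) = delete;
  OwnedInts &operator=(const OwnedInts &) = delete;

  int *get() const { return data_; }

 private:
  int *data_;
  FreeFn release_;
};
```

Deleting the copy operations is a deliberate choice: an owner copied with an implicit copy constructor would free the same pointer twice when both copies are destroyed.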

Other Functions

Producing a Basic Visualization of the Mandelbrot Set

The Mandel class also contains member functions for data visualization. WriteImage() generates a PNG image of the data, where each pixel represents a point in the complex plane and its color encodes the iteration count calculated by Point().

void WriteImage() {
  constexpr int channel_num{3};
  int row_count = p_.row_count();
  int col_count = p_.col_count();

  uint8_t *pixels = new uint8_t[col_count * row_count * channel_num];

  int index = 0;

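  // data_ is row-major by (row, column), where rows follow the real axis and
  // columns the imaginary axis. Write the image transposed so the real axis
  // runs horizontally: width = row_count, height = col_count.
  for (int j = 0; j < col_count; ++j) {
    for (int i = 0; i < row_count; ++i) {
      float normalized = (1.0 * data_[i * col_count + j]) / max_iterations;
      int color = int(normalized * 0xFFFFFF);  // 16M color.

      int r = (color >> 16) & 0xFF;
      int g = (color >> 8) & 0xFF;
      int b = color & 0xFF;

      pixels[index++] = r;
      pixels[index++] = g;
      pixels[index++] = b;
    }
  }

  // Stride is the byte length of one image row: width * channels.
  stbi_write_png("mandelbrot.png", row_count, col_count, channel_num, pixels,
                 row_count * channel_num);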

  delete[] pixels;
}
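The color mapping inside WriteImage() can be isolated as a small function, shown below as a sketch. `IterationsToRgb` is a hypothetical name; the arithmetic follows the loop body above, normalizing the count, scaling it to a 24-bit value, and splitting that value into red, green, and blue bytes.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Standalone version of the color mapping in WriteImage(): the iteration
// count is normalized to [0, 1], scaled to a 24-bit value (0xFFFFFF), and
// split into 8-bit red, green, and blue channels.
std::array<std::uint8_t, 3> IterationsToRgb(int count, int max_iterations) {
  assert(max_iterations > 0);
  float normalized = (1.0f * count) / max_iterations;
  int color = static_cast<int>(normalized * 0xFFFFFF);

  return {{static_cast<std::uint8_t>((color >> 16) & 0xFF),
           static_cast<std::uint8_t>((color >> 8) & 0xFF),
           static_cast<std::uint8_t>(color & 0xFF)}};
}
```

Points that never escape (count == max_iterations) are drawn white, and count 0 is black. Because the whole 24-bit value is split across channels, this is not a smooth grayscale ramp: red tracks the coarse iteration depth, while green and blue wrap around many times as the count increases.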

Example Image of Data Output

The Mandel class's Print() member function writes a similar visualization to stdout.

Summary

This walkthrough demonstrates how USM lets you manage data in host and device memory with familiar C/C++ pointer patterns, using the Mandelbrot set as a test case.