You’ve probably seen a CUDA tutorial like this one — a classic “Hello World” blending CPU and GPU code in a single “heterogeneous” CUDA C++ source file, with the kernel launched using NVCC’s now-iconic triple-bracket <<<>>> syntax:

#include <cuda_runtime.h>
#include <stdio.h>

__global__ void kernel() {
    printf("Hello World from block %d, thread %d\n", blockIdx.x, threadIdx.x);
}

int main() {
    kernel<<<1, 1>>>(); // Returns `void`?! 🤬    
    return cudaDeviceSynchronize() == cudaSuccess ? 0 : -1;
}

I still see this exact pattern in production code — and I’ll admit, it shows up in some of my own toy projects too: one, two, and three. But relying on triple-bracket kernel launches in production isn’t ideal. They don’t return error codes, and they encourage a false sense of simplicity. So in the next ~25 KB of text, we’ll explore the less wrong ways to launch kernels.

Old School Graphics

This post doesn’t teach you how to write CUDA kernels. It won’t cover semaphores, event graphs, or full-scale schedulers. Instead, it focuses on one thing: launching a single CUDA kernel correctly.

Basics and Correctness

The snippet above compiles and runs with the expected output:

$ nvcc -o hello_world hello_world.cu && ./hello_world
> Hello World from block 0, thread 0

In some sense, it’s already “correct”.

However, today’s datacenter GPUs typically ship eight to an HGX board, and at the very least, you’d want to introduce enough parallelism to launch kernels across all of them.
Basic utilization aside, a single DGX H100 node comes with two beefy CPUs, constantly juggling complex execution graphs and moving hundreds of gigabytes per second in both directions — all asynchronously.

A lot can go wrong in such a system, so it’s worth establishing a few basic ground rules for how GPU kernels should be orchestrated:

  • Kernel launches are high-latency and should be asynchronous.
  • Work should be explicitly ordered within streams.
  • CUDA API calls and kernel launches must be paired with robust error checks (a minimal macro sketch follows this list).
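Throughout this post the checks stay spelled out longhand so nothing is hidden, but in real code they usually collapse into a small macro. Here is a minimal sketch; the CUDA_TRY name is mine, not part of the toolkit:

#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

// Hypothetical helper: abort with file/line context on any CUDA Runtime API failure.
#define CUDA_TRY(call)                                                            \
    do {                                                                          \
        cudaError_t error_ = (call);                                              \
        if (error_ != cudaSuccess) {                                              \
            fprintf(stderr, "%s:%d: %s failed: %s\n", __FILE__, __LINE__, #call,  \
                    cudaGetErrorString(error_));                                  \
            exit(EXIT_FAILURE);                                                   \
        }                                                                         \
    } while (0)

// Usage: CUDA_TRY(cudaStreamCreate(&stream));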

Here’s how to integrate CUDA streams — nothing fancy:

#include <cuda_runtime.h>
#include <stdio.h>

__global__ void kernel() {
    extern __shared__ char shared_buffer[]; // unused here; sized at launch by the dynamic shared-memory argument
    printf("Hello World from block %d, thread %d\n", blockIdx.x, threadIdx.x);
}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    size_t shared_memory_size = 0; // dynamic shared-memory bytes, the 3rd launch argument
    kernel<<<1, 1, shared_memory_size, stream>>>(); // 4 arguments, not 2
    cudaStreamSynchronize(stream);
    cudaStreamDestroy(stream);
    return 0;
}

Note the four-argument launch configuration: grid dimensions, block dimensions, dynamic shared-memory bytes, and the target stream.


Here’s a more careful version with explicit error handling:

#include <cuda_runtime.h>
#include <stdio.h>

__global__ void kernel() {
    extern __shared__ char shared_buffer[];
    printf("Hello World from block %d, thread %d\n", blockIdx.x, threadIdx.x);
}

int main() {
    cudaStream_t stream;
    cudaError_t err = cudaStreamCreate(&stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Failed to create stream: %s\n", cudaGetErrorString(err));
        return -1;
    }
    size_t shared_memory_size = 1 << 30; // 1 GiB, intentionally absurd for demonstration
    kernel<<<1, 1, shared_memory_size, stream>>>();
    err = cudaStreamSynchronize(stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Failed to synchronize stream: %s\n", cudaGetErrorString(err));
        cudaStreamDestroy(stream);
        return -1;
    }
    cudaStreamDestroy(stream);
    return 0;
}

This is where most tutorials stop — and silently fail. Try running it:

$ nvcc -o hello_world hello_world.cu && ./hello_world

No output. No error message. Nothing. We can’t fetch an error from the stream because the failure happens on submission, not execution.
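To be fair, the triple-bracket launch isn’t entirely uncheckable: the submission error can be fished out with cudaGetLastError() right after the launch statement. A minimal sketch, dropped into the main() above:

    kernel<<<1, 1, shared_memory_size, stream>>>();
    cudaError_t launch_err = cudaGetLastError(); // reports submission failures, e.g. "invalid argument"
    if (launch_err != cudaSuccess)
        fprintf(stderr, "Failed to launch kernel: %s\n", cudaGetErrorString(launch_err));

It works, but it’s easy to forget and relies on thread-local “last error” state, so let’s look at APIs that return the error directly.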

CUDA Runtime API

NVCC’s triple-bracket syntax is syntactic sugar over the CUDA Runtime API, itself a wrapper over the lower-level CUDA Driver API. To catch submission errors reliably, we can drop the sugar and use the Runtime’s explicit Execution Control API, where cudaLaunchKernel returns an error code directly:

#include <cuda_runtime.h>
#include <stdio.h>

__global__ void kernel() {
    extern __shared__ char shared_buffer[];
    printf("Hello World from block %d, thread %d\n", blockIdx.x, threadIdx.x);
}

int main() {
    cudaStream_t stream;
    cudaError_t err;
    err = cudaStreamCreate(&stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Failed to create stream: %s\n", cudaGetErrorString(err));
        return -1;
    }

    dim3 grid(1);
    dim3 block(1);
    size_t shared_memory_size = 1 << 30; // 1 GB
    void *kernel_args[] = {}; // empty: the kernel takes no parameters
    err = cudaLaunchKernel((void *)kernel, grid, block, kernel_args, shared_memory_size, stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Failed to launch kernel: %s\n", cudaGetErrorString(err));
        cudaStreamDestroy(stream);
        return -1;
    }

    err = cudaStreamSynchronize(stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Kernel execution failed: %s\n", cudaGetErrorString(err));
        cudaStreamDestroy(stream);
        return -1;
    }
    err = cudaStreamDestroy(stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Failed to destroy stream: %s\n", cudaGetErrorString(err));
        return -1;
    }
    return 0;
}

Compile and run it:

$ nvcc -o hello_world hello_world.cu && ./hello_world
> Failed to launch kernel: invalid argument

This error is expected: the kernel can’t even be submitted because of the absurd shared-memory request. The trade-off with this API is that kernel arguments are passed differently, as an array of pointers to each argument. Here is how we would pass arrays and scalars to the kernel:

#include <cuda_runtime.h>
#include <stdio.h>

__global__ void kernel(double *amount, size_t count, int power) {
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= count) return;
    amount[idx] = amount[idx] * scalbln(1.0, power); // An example of a CUDA intrinsic ;)
}

int main() {
    cudaError_t err;
    size_t num_elements = 1024;
    int integral_power = -2;
    double *data;

    // Allocate unified memory
    err = cudaMallocManaged(&data, num_elements * sizeof(double));
    if (err != cudaSuccess) {
        fprintf(stderr, "cudaMallocManaged failed: %s\n", cudaGetErrorString(err));
        return -1;
    }

    // Initialize data
    for (size_t i = 0; i < num_elements; ++i) data[i] = (double)i;

    // Create CUDA stream
    cudaStream_t stream;
    err = cudaStreamCreate(&stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Failed to create stream: %s\n", cudaGetErrorString(err));
        cudaFree(data);
        return -1;
    }

    // Define kernel launch parameters
    dim3 grid((num_elements + 255) / 256);
    dim3 block(256);
    void *kernel_args[] = {
        (void *)&data,
        (void *)&num_elements,
        (void *)&integral_power,
    };

    // Launch kernel
    err = cudaLaunchKernel((void *)kernel, grid, block, kernel_args, 0, stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Failed to launch kernel: %s\n", cudaGetErrorString(err));
        cudaStreamDestroy(stream);
        cudaFree(data);
        return -1;
    }

    // Synchronize stream
    err = cudaStreamSynchronize(stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Kernel execution failed: %s\n", cudaGetErrorString(err));
        cudaStreamDestroy(stream);
        cudaFree(data);
        return -1;
    }

    // Print results
    for (size_t i = 0; i < 5; ++i) printf("data[%zu] = %f\n", i, data[i]);
    cudaStreamDestroy(stream);
    cudaFree(data);
    return 0;
}

I used unified memory to make the example simpler: we don’t have to explicitly allocate two separate buffers on the CPU and GPU and copy data between them. The driver maintains copies of the data in host and device memory and automatically migrates updates between them when needed.
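For contrast, here is a hedged sketch of the same data movement without unified memory, using explicit device allocations and asynchronous copies. The run_without_unified_memory helper name is mine, error checks are omitted for brevity, and kernel refers to the kernel above:

#include <cuda_runtime.h>
#include <stdlib.h>

__global__ void kernel(double *amount, size_t count, int power); // same kernel as above

// Sketch only: explicit host and device buffers instead of cudaMallocManaged.
void run_without_unified_memory(size_t num_elements, int integral_power, cudaStream_t stream) {
    double *host_data = (double *)malloc(num_elements * sizeof(double));
    for (size_t i = 0; i < num_elements; ++i) host_data[i] = (double)i;

    double *device_data = NULL;
    cudaMalloc((void **)&device_data, num_elements * sizeof(double));
    // With pageable host memory this copy degrades to a synchronous one;
    // allocate host_data with cudaMallocHost to make it truly asynchronous.
    cudaMemcpyAsync(device_data, host_data, num_elements * sizeof(double), cudaMemcpyHostToDevice, stream);

    dim3 grid((num_elements + 255) / 256), block(256);
    void *kernel_args[] = {&device_data, &num_elements, &integral_power};
    cudaLaunchKernel((void *)kernel, grid, block, kernel_args, 0, stream);

    cudaMemcpyAsync(host_data, device_data, num_elements * sizeof(double), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream); // results are now visible in host_data
    cudaFree(device_data);
    free(host_data);
}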

Cooperative Groups

It’s long been a dream that we could write parallel GPU algorithms once — stack a few abstractions, wrap them in templates, and let the runtime figure things out. In practice, that rarely works out. Unfortunately, the CUDA Cooperative Groups API is no exception.


It was designed as a unified abstraction for coordinating threads beyond a single block — using C++ intrinsics to schedule complex GPU algorithms with more flexible synchronization semantics. In theory, this should solve a significant problem: letting all threads on the device synchronize before progressing, which is essential for iterative algorithms like physics simulations or solvers.

As a reminder:

  • __syncwarp() for a warp of 32 threads.
  • __syncthreads() for a logical block of 1-1024 threads.
  • For everything else, there’s Mastercard Cooperative Groups (a quick sketch of all three levels follows this list).
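Here is a hedged, minimal sketch of the three levels side by side; the grid-wide variant only takes effect when the kernel is launched cooperatively, which is what the rest of this section covers:

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void sync_levels() {
    // Warp level: only the 32 threads of the calling warp rendezvous here.
    __syncwarp();
    // Block level: every thread of this block waits here.
    __syncthreads();
    // Grid level: every thread of every block waits here, but only if the
    // kernel was launched with cudaLaunchCooperativeKernel.
    cg::this_grid().sync();
}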

The most obvious example would be synchronizing the whole grid in multi-step iterative algorithms, like physics simulations. For that, NVIDIA recommends using the new cooperative_groups::sync() function:

#include <cuda_runtime.h>
#include <cooperative_groups.h>
#include <stdio.h>
#include <math.h>

namespace cg = cooperative_groups;

__device__ float3 compute_force(float3 position_first, float3 position_second) {
    float3 r;
    r.x = position_second.x - position_first.x;
    r.y = position_second.y - position_first.y;
    r.z = position_second.z - position_first.z;

    float squared_distance = r.x * r.x + r.y * r.y + r.z * r.z + 1e-6f; // avoid div by zero
    float reciprocal_distance = rsqrtf(squared_distance);
    float reciprocal_cube = reciprocal_distance * reciprocal_distance * reciprocal_distance;

    constexpr float gravitational_constant = 1.0f;
    float scale = gravitational_constant * reciprocal_cube;
    r.x *= scale;
    r.y *= scale;
    r.z *= scale;
    return r;
}

__global__ void cooperative_kernel(
    float3 *positions_old, float3 *positions_new,
    float3 *velocities_old, float3 *velocities_new,
    size_t count, size_t iterations, float dt) {

    cg::grid_group grid = cg::this_grid();
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= count) return; // safe here: count always matches the grid size, so no thread skips grid.sync()

    for (size_t iter = 0; iter < iterations; ++iter) {
        float3 force = {0.0f, 0.0f, 0.0f};

        // Accumulate forces from all other particles
        for (size_t j = 0; j < count; ++j) {
            if (j == idx) continue;
            float3 f = compute_force(positions_old[idx], positions_old[j]);
            force.x += f.x;
            force.y += f.y;
            force.z += f.z;
        }

        // Update velocity and position
        velocities_new[idx].x = velocities_old[idx].x + force.x * dt;
        velocities_new[idx].y = velocities_old[idx].y + force.y * dt;
        velocities_new[idx].z = velocities_old[idx].z + force.z * dt;
        positions_new[idx].x = positions_old[idx].x + velocities_new[idx].x * dt;
        positions_new[idx].y = positions_old[idx].y + velocities_new[idx].y * dt;
        positions_new[idx].z = positions_old[idx].z + velocities_new[idx].z * dt;

        // Wait for every thread to finish writing before the buffers are reused
        grid.sync();
        // Swap the thread-local buffer pointers for the next iteration
        float3 *temp_pos = positions_old, *temp_vel = velocities_old;
        positions_old = positions_new, positions_new = temp_pos;
        velocities_old = velocities_new, velocities_new = temp_vel;
        grid.sync();
    }
}

int main() {
    cudaError_t err;
    size_t num_particles = 256;
    size_t iterations = 10;
    float dt = 0.01f;
    dim3 block;
    dim3 grid;
    void *kernel_args[7];
    float3 *positions_old = nullptr, *positions_new = nullptr;
    float3 *velocities_old = nullptr, *velocities_new = nullptr;

    // Allocate memory
    err = cudaMallocManaged(&positions_old, num_particles * sizeof(float3));
    if (err != cudaSuccess) goto cleanup;
    err = cudaMallocManaged(&positions_new, num_particles * sizeof(float3));
    if (err != cudaSuccess) goto cleanup;
    err = cudaMallocManaged(&velocities_old, num_particles * sizeof(float3));
    if (err != cudaSuccess) goto cleanup;
    err = cudaMallocManaged(&velocities_new, num_particles * sizeof(float3));
    if (err != cudaSuccess) goto cleanup;

    // Initialize positions and velocities
    for (size_t i = 0; i < num_particles; ++i) {
        float theta = (float)i * 0.01f;
        float phi = (float)i * 0.005f;
        float radius = 10.0f + (i % 32) * 0.1f;
        positions_old[i] = {radius * cosf(theta) * sinf(phi), radius * sinf(theta) * sinf(phi), radius * cosf(phi)};
        velocities_old[i] = {0.01f * sinf(phi), 0.01f * cosf(theta), 0.01f * sinf(theta + phi)};
    }

    // Make sure the device supports cooperative launch
    cudaDeviceProp props;
    err = cudaGetDeviceProperties(&props, 0);
    if (err != cudaSuccess || !props.cooperativeLaunch) {
        fprintf(stderr, "Cooperative launch not supported on this device.\n");
        err = cudaErrorNotSupported;
        goto cleanup;
    }
    cudaStream_t stream;
    err = cudaStreamCreate(&stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Failed to create stream: %s\n", cudaGetErrorString(err));
        goto cleanup;
    }

    block = dim3(256);
    grid = dim3((num_particles + block.x - 1) / block.x);
    kernel_args[0] = &positions_old;
    kernel_args[1] = &positions_new;
    kernel_args[2] = &velocities_old;
    kernel_args[3] = &velocities_new;
    kernel_args[4] = &num_particles;
    kernel_args[5] = &iterations;
    kernel_args[6] = &dt;

    // Launch the kernel
    err = cudaLaunchCooperativeKernel((void *)cooperative_kernel, grid, block, kernel_args, 0, stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Failed to launch cooperative kernel: %s\n", cudaGetErrorString(err));
        goto cleanup;
    }
    err = cudaStreamSynchronize(stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Kernel execution failed: %s\n", cudaGetErrorString(err));
        goto cleanup;
    }

    // Print final positions
    for (size_t i = 0; i < num_particles; ++i)
        printf("Final position[%zu] = (%f, %f, %f)\n", i, positions_old[i].x, positions_old[i].y, positions_old[i].z);

cleanup:
    if (positions_old) cudaFree(positions_old);
    if (positions_new) cudaFree(positions_new);
    if (velocities_old) cudaFree(velocities_old);
    if (velocities_new) cudaFree(velocities_new);
    return (err == cudaSuccess) ? 0 : -1;
}

Notice how I’ve replaced cudaLaunchKernel with cudaLaunchCooperativeKernel and added a cg::grid_group object to the kernel. If we were to use the old non-cooperative cudaLaunchKernel launch API, we would get:

$ nvcc -o hello_world hello_world.cu && ./hello_world
> Kernel execution failed: unspecified launch failure

So, we need to use the new “cooperative” API, which initially looked promising. With faster GPU-GPU interconnects and in-node NVLink switches, I hoped we’d see more robust synchronization primitives for multi-GPU systems. For a moment, that future felt near: the CUDA runtime introduced cudaLaunchCooperativeKernelMultiDevice and the cg::multi_grid_group abstraction — the missing pieces for coordinating kernels across multiple GPUs. However, both were deprecated in CUDA 11.3, making them some of the shortest-lived APIs in CUDA history.
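One practical constraint worth noting: a cooperative launch requires the entire grid to be resident on the device at once, so the grid is normally sized from occupancy rather than from the problem size, and work is distributed with a grid-stride loop. A hedged sketch, with a helper name of my own:

#include <cuda_runtime.h>

// Hypothetical helper: the largest grid that can be co-resident, which is what
// cudaLaunchCooperativeKernel requires; bigger grids fail with
// cudaErrorCooperativeLaunchTooLarge.
static dim3 max_cooperative_grid(const void *kernel_fn, int block_size) {
    int device = 0, sm_count = 0, blocks_per_sm = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel_fn, block_size, 0);
    return dim3(sm_count * blocks_per_sm);
}

In the example above the grid is a single block, so co-residency is trivially satisfied, but it becomes a real limit once the particle count grows.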

While Cooperative Groups are attractive in concept, I don’t see them scaling up meaningfully. Most of my work happens at the warp level with __syncwarp(), and if I need to go higher, I generally prefer to write inline PTX assembly. We can quite easily check what the cooperative_groups::sync() function compiles to by using the -ptx flag with NVCC:

$ nvcc -arch=sm_80 -ptx -o hello_world.ptx hello_world.cu
$ grep -A 1 "barrier.sync" hello_world.ptx

This confirms what we suspected: for a grid like ours that fits in a single block, the sync boils down to a barrier.sync instruction (multi-block grids wrap extra bookkeeping around it). So, if you’re comfortable with inline PTX, you can recreate the same behavior without the <cooperative_groups.h> header. And if you’re writing performance-critical code, it’s not just cleaner — it’s more transparent and debuggable.

#include <cuda_runtime.h>
#include <stdio.h>
#include <math.h>

// CTA-wide barrier with a compiler memory fence; equivalent to grid.sync() only when the grid is a single block
__device__ inline void grid_sync_ptx() { asm volatile("barrier.sync 0;" ::: "memory"); }

__device__ float3 compute_force(float3 position_first, float3 position_second) {
    float3 r;
    r.x = position_second.x - position_first.x;
    r.y = position_second.y - position_first.y;
    r.z = position_second.z - position_first.z;

    float squared_distance = r.x * r.x + r.y * r.y + r.z * r.z + 1e-6f; // avoid div by zero
    float reciprocal_distance = rsqrtf(squared_distance);
    float reciprocal_cube = reciprocal_distance * reciprocal_distance * reciprocal_distance;

    constexpr float gravitational_constant = 1.0f;
    float scale = gravitational_constant * reciprocal_cube;
    r.x *= scale;
    r.y *= scale;
    r.z *= scale;
    return r;
}

__global__ void cooperative_kernel(float3 *positions_old, float3 *positions_new, float3 *velocities_old, float3 *velocities_new, size_t count,
                                   size_t iterations, float dt) {
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= count) return;

    for (size_t iter = 0; iter < iterations; ++iter) {
        float3 force = {0.0f, 0.0f, 0.0f};

        // Accumulate forces from all other particles
        for (size_t j = 0; j < count; ++j) {
            if (j == idx) continue;
            float3 f = compute_force(positions_old[idx], positions_old[j]);
            force.x += f.x;
            force.y += f.y;
            force.z += f.z;
        }

        // Update velocity and position
        velocities_new[idx].x = velocities_old[idx].x + force.x * dt;
        velocities_new[idx].y = velocities_old[idx].y + force.y * dt;
        velocities_new[idx].z = velocities_old[idx].z + force.z * dt;

        positions_new[idx].x = positions_old[idx].x + velocities_new[idx].x * dt;
        positions_new[idx].y = positions_old[idx].y + velocities_new[idx].y * dt;
        positions_new[idx].z = positions_old[idx].z + velocities_new[idx].z * dt;

        grid_sync_ptx();

        // Swap buffers for the next iteration
        float3 *temp_pos = positions_old, *temp_vel = velocities_old;
        positions_old = positions_new, positions_new = temp_pos;
        velocities_old = velocities_new, velocities_new = temp_vel;

        grid_sync_ptx();
    }
}

int main() {
    cudaError_t err;
    size_t num_particles = 256;
    size_t iterations = 10;
    float dt = 0.01f;

    float3 *positions_old = nullptr, *positions_new = nullptr;
    float3 *velocities_old = nullptr, *velocities_new = nullptr;

    dim3 block;
    dim3 grid;
    void *kernel_args[7];

    // Allocate memory
    err = cudaMallocManaged(&positions_old, num_particles * sizeof(float3));
    if (err != cudaSuccess) goto cleanup;
    err = cudaMallocManaged(&positions_new, num_particles * sizeof(float3));
    if (err != cudaSuccess) goto cleanup;
    err = cudaMallocManaged(&velocities_old, num_particles * sizeof(float3));
    if (err != cudaSuccess) goto cleanup;
    err = cudaMallocManaged(&velocities_new, num_particles * sizeof(float3));
    if (err != cudaSuccess) goto cleanup;

    for (size_t i = 0; i < num_particles; ++i) {
        float theta = (float)i * 0.01f;
        float phi = (float)i * 0.005f;
        float radius = 10.0f + (i % 32) * 0.1f;
        positions_old[i] = {radius * cosf(theta) * sinf(phi), radius * sinf(theta) * sinf(phi), radius * cosf(phi)};
        velocities_old[i] = {0.01f * sinf(phi), 0.01f * cosf(theta), 0.01f * sinf(theta + phi)};
    }

    cudaDeviceProp props;
    cudaGetDeviceProperties(&props, 0);
    if (!props.cooperativeLaunch) {
        fprintf(stderr, "Cooperative launch not supported on this device.\n");
        err = cudaErrorNotSupported;
        goto cleanup;
    }

    cudaStream_t stream;
    err = cudaStreamCreate(&stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Failed to create stream: %s\n", cudaGetErrorString(err));
        goto cleanup;
    }

    block = dim3(256);
    grid = dim3((num_particles + block.x - 1) / block.x);
    kernel_args[0] = &positions_old;
    kernel_args[1] = &positions_new;
    kernel_args[2] = &velocities_old;
    kernel_args[3] = &velocities_new;
    kernel_args[4] = &num_particles;
    kernel_args[5] = &iterations;
    kernel_args[6] = &dt;

    err = cudaLaunchKernel((void *)cooperative_kernel, grid, block, kernel_args, 0, stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Failed to launch cooperative kernel: %s\n", cudaGetErrorString(err));
        goto cleanup;
    }

    err = cudaStreamSynchronize(stream);
    if (err != cudaSuccess) {
        fprintf(stderr, "Kernel execution failed: %s\n", cudaGetErrorString(err));
        goto cleanup;
    }

    for (size_t i = 0; i < num_particles; ++i)
        printf("Final position[%zu] = (%f, %f, %f)\n", i, positions_old[i].x, positions_old[i].y, positions_old[i].z);

cleanup:
    if (positions_old) cudaFree(positions_old);
    if (positions_new) cudaFree(positions_new);
    if (velocities_old) cudaFree(velocities_old);
    if (velocities_new) cudaFree(velocities_new);
    return (err == cudaSuccess) ? 0 : -1;
}

And yes, I’ve launched it with the old-school cudaLaunchKernel rather than cudaLaunchCooperativeKernel while no one was watching. No runtime complaints, and no correctness ones either, but only because barrier.sync 0 is a block-level barrier and our 256 particles fit in one 256-thread block, so the whole “grid” here is that single block. PTX has a whole menu of other barriers if you’re feeling adventurous.
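For instance, barrier.sync takes an optional barrier name (0-15) and a participating-thread count, so a subset of a block can rendezvous without stalling the rest. A hedged sketch; the numbers are illustrative, and the count must be a multiple of the warp size:

__global__ void partial_barrier_demo() {
    // Named barrier 1 with 128 participants: the first four warps of the block
    // synchronize among themselves; the remaining warps never touch it.
    if (threadIdx.x < 128)
        asm volatile("barrier.sync 1, 128;" ::: "memory");
}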

CUDA Driver API

Lastly, insufficient attention is paid to the lower-level CUDA Driver API — which can be extremely useful in production. Sure, it’s a bit more verbose, but it gives you complete control over kernel loading and launching, including support for dynamically loading PTX, CUBINs, or SASS at runtime. Going the extra mile, I recommend separating the kernel code from the host code entirely, using separate compilers for each, and introducing a stable ABI between them. Here is what our vanilla C99 host code may look like:

#include <cuda.h>
#include <stdio.h>
#include <math.h>

#define CUDA_CHECK(err)                                 \
    do {                                                \
        if ((err) != CUDA_SUCCESS) {                    \
            const char *msg;                            \
            cuGetErrorString((err), &msg);              \
            fprintf(stderr, "CUDA error: %s\n", msg);   \
            goto cleanup;                               \
        }                                               \
    } while (0)

int main() {
    CUresult err;
    size_t num_particles = 256;
    size_t iterations = 10;
    float dt = 0.01f;

    CUdevice device;
    CUcontext context = NULL;
    CUmodule module = NULL;
    CUfunction kernel;
    CUstream stream = NULL;
    float *positions_old = NULL, *positions_new = NULL;
    float *velocities_old = NULL, *velocities_new = NULL;
    void *kernel_args[7];

    // Initialize CUDA
    err = cuInit(0);
    CUDA_CHECK(err);
    err = cuDeviceGet(&device, 0);
    CUDA_CHECK(err);
    err = cuCtxCreate(&context, 0, device);
    CUDA_CHECK(err);
    err = cuStreamCreate(&stream, CU_STREAM_DEFAULT);
    CUDA_CHECK(err);

    // Load the "bytecode" PTX, that will later be JIT-compiled to SASS
    err = cuModuleLoad(&module, "hello_world.ptx");
    CUDA_CHECK(err);
    err = cuModuleGetFunction(&kernel, module, "cooperative_kernel");
    CUDA_CHECK(err);

    // Allocate managed memory for positions and velocities
    size_t buffer_size = num_particles * sizeof(float) * 3;
    err = cuMemAllocManaged((CUdeviceptr *)&positions_old, buffer_size, CU_MEM_ATTACH_GLOBAL);
    CUDA_CHECK(err);
    err = cuMemAllocManaged((CUdeviceptr *)&positions_new, buffer_size, CU_MEM_ATTACH_GLOBAL);
    CUDA_CHECK(err);
    err = cuMemAllocManaged((CUdeviceptr *)&velocities_old, buffer_size, CU_MEM_ATTACH_GLOBAL);
    CUDA_CHECK(err);
    err = cuMemAllocManaged((CUdeviceptr *)&velocities_new, buffer_size, CU_MEM_ATTACH_GLOBAL);
    CUDA_CHECK(err);

    // Initialize positions and velocities
    for (size_t i = 0; i < num_particles; ++i) {
        float theta = (float)i * 0.01f;
        float phi = (float)i * 0.005f;
        float radius = 10.0f + (i % 32) * 0.1f;
        positions_old[3 * i + 0] = radius * cosf(theta) * sinf(phi);
        positions_old[3 * i + 1] = radius * sinf(theta) * sinf(phi);
        positions_old[3 * i + 2] = radius * cosf(phi);
        velocities_old[3 * i + 0] = 0.01f * sinf(phi);
        velocities_old[3 * i + 1] = 0.01f * cosf(theta);
        velocities_old[3 * i + 2] = 0.01f * sinf(theta + phi);
    }

    kernel_args[0] = &positions_old;
    kernel_args[1] = &positions_new;
    kernel_args[2] = &velocities_old;
    kernel_args[3] = &velocities_new;
    kernel_args[4] = &num_particles;
    kernel_args[5] = &iterations;
    kernel_args[6] = &dt;

    // Launch the kernel
    int threads_per_block = 256;
    int blocks_per_grid = (num_particles + threads_per_block - 1) / threads_per_block;
    err = cuLaunchKernel(kernel,
                         blocks_per_grid, 1, 1,
                         threads_per_block, 1, 1,
                         0, stream,
                         kernel_args, NULL);
    CUDA_CHECK(err);
    err = cuStreamSynchronize(stream);
    CUDA_CHECK(err);

    // Log the final positions
    for (size_t i = 0; i < num_particles; ++i)
        printf("Final position[%zu] = (%f, %f, %f)\n", i,
               positions_old[3 * i + 0],
               positions_old[3 * i + 1],
               positions_old[3 * i + 2]);

cleanup:
    if (stream) cuStreamDestroy(stream);
    if (positions_old) cuMemFree((CUdeviceptr)positions_old);
    if (positions_new) cuMemFree((CUdeviceptr)positions_new);
    if (velocities_old) cuMemFree((CUdeviceptr)velocities_old);
    if (velocities_new) cuMemFree((CUdeviceptr)velocities_new);
    if (module) cuModuleUnload(module);
    if (context) cuCtxDestroy(context);
    return (err == CUDA_SUCCESS) ? 0 : -1;
}

The only thing you need to watch out for is name mangling. To make this work, ensure the kernel declaration in your .cu file is wrapped with extern "C":

extern "C" __global__ void cooperative_kernel(...);

With that in place, you can compile the GPU side to PTX using NVCC and the host side using GCC — entirely independently:

$ nvcc -arch=sm_80 -ptx -o hello_world.ptx hello_world.cu
$ gcc -o hello_world hello_world.c \
    -I/usr/local/cuda/include \
    -L/usr/local/cuda/lib64 \
    -lcuda -lm && ./hello_world
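If shipping a loose .ptx file next to the executable feels fragile, the Driver API can also JIT the PTX from a string embedded at build time. A hedged sketch; kernel_ptx is a hypothetical null-terminated char array holding the PTX text:

#include <cuda.h>
#include <stdio.h>

extern const char kernel_ptx[]; // hypothetical: PTX text baked into the binary at build time

static CUresult load_embedded_module(CUmodule *module) {
    // Ask the JIT to write compilation errors into a local buffer instead of failing silently.
    char error_log[4096] = {0};
    CUjit_option options[] = {CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
    void *option_values[] = {error_log, (void *)(size_t)sizeof(error_log)};
    CUresult err = cuModuleLoadDataEx(module, kernel_ptx, 2, options, option_values);
    if (err != CUDA_SUCCESS) fprintf(stderr, "PTX JIT failed: %s\n", error_log);
    return err;
}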

Conclusion

This is, of course, much more verbose than the original <<<1, 1>>> example — but doing things right usually is. That said, plenty of tools have popped up in recent years to simplify prototyping CUDA code, including the many DSLs and compilers NVIDIA just showcased at the last GTC. But the way we ship and launch production kernels has remained remarkably stable for over a decade — almost identical to what you’d do with OpenCL or CUDA circa 2010.

Don’t judge too harshly on slide 35. Memory ordering and control flow on one slide was probably a bad idea — and part of the motivation for writing this post.

What has changed is the complexity inside the kernels: they’re now less data-parallel and more like concurrent CPU algorithms, with atomics, warp-level reductions, and specialized Tensor Core logic baked in. Those look and feel different on every GPU generation, and porting my other libraries to Blackwell has made that painfully clear. So expect a few more posts soon 😉