Trying to understand how to implement a mass spring setup (cloth sim) with a compute shader

Started by Gregm8; 2 comments, last by Sprue 10 months, 3 weeks ago

I am trying to implement a mass spring system using a compute shader and have run into a weird issue.

Essentially, the way I have set up the system is with three separate compute shaders, each with its own purpose.

The first is as follows:

struct Vertex
{
    float3 position;
    float3 normal;
    float2 texCoord;
    float4 tangent;
};

struct InterlockedVector
{
    int x;
    int y;
    int z;
};

// skinning that has been transferred onto the simulation mesh
StructuredBuffer<Vertex> simMeshSkinnedVertexBuffer : register(t0);
// the skinning data from the previous frame
RWStructuredBuffer<Vertex> simMeshPreviousSkinnedVertexBuffer : register(u0);
// the output result stored as integers so that InterlockedAdd can be used to manage synchronisation
RWStructuredBuffer<InterlockedVector> simMeshTransformedVertexBuffer : register(u1);

float3 UnQuantize(InterlockedVector input, float factor)
{
    float vertexPositionX = ((float) input.x) / factor;
    float vertexPositionY = ((float) input.y) / factor;
    float vertexPositionZ = ((float) input.z) / factor;
    return float3(vertexPositionX, vertexPositionY, vertexPositionZ);
}

// Define the compute shader entry point
[numthreads(64, 1, 1)]
void CS(uint3 dispatchThreadID : SV_DispatchThreadID)
{
    // Get the current vertex ID
    uint simMeshVertexID = dispatchThreadID.x;
    
    // Get the change in driving force from the relative transformation between frames
    float3 simSkinForce = simMeshSkinnedVertexBuffer[simMeshVertexID].position - simMeshPreviousSkinnedVertexBuffer[simMeshVertexID].position;

    // convert data to integer
    int quantizedX = (int) (simSkinForce.x * QUANTIZE);
    int quantizedY = (int) (simSkinForce.y * QUANTIZE);
    int quantizedZ = (int) (simSkinForce.z * QUANTIZE);
 
    // perform atomic addition
    InterlockedAdd(simMeshTransformedVertexBuffer[simMeshVertexID].x, quantizedX);
    InterlockedAdd(simMeshTransformedVertexBuffer[simMeshVertexID].y, quantizedY);
    InterlockedAdd(simMeshTransformedVertexBuffer[simMeshVertexID].z, quantizedZ);
    
    // store skinning as previous
    simMeshPreviousSkinnedVertexBuffer[simMeshVertexID] = simMeshSkinnedVertexBuffer[simMeshVertexID];
}

The purpose of this part of the code is to feed the changes to the skinned mesh in as a driving force for the mass spring simulation. This part seems to be working well from what I can see.

I am using InterlockedAdd as I had a suspicion that I was having some synchronization issues.

This Pre Solve compute shader is dispatched like this:

const UINT threadGroupSizeX = 64;
const UINT threadGroupSizeY = 1;
const UINT threadGroupSizeZ = 1;
cmdList->Dispatch((ri->SimMeshVertexCount + threadGroupSizeX - 1) / threadGroupSizeX, threadGroupSizeY, threadGroupSizeZ);

The next part is where I am running into some issues.

struct Neighbours
{
    uint index[8];
};

struct RestConstraint
{
    float length[8];
};

struct InterlockedVector
{
    int x;
    int y;
    int z;
};

// The rest length of each neighbouring vertex
StructuredBuffer<RestConstraint> restConstraintBuffer : register(t0);
// The vertexID of each neighbouring vertex
StructuredBuffer<Neighbours> simMeshVertexAdjacencyBuffer : register(t1);
// The output result stored as integers so that InterlockedAdd can be used to manage synchronisation
RWStructuredBuffer<InterlockedVector> simMeshTransformedVertexBuffer : register(u0);

float3 UnQuantize(InterlockedVector input, float factor)
{
    float vertexPositionX = ((float) input.x) / factor;
    float vertexPositionY = ((float) input.y) / factor;
    float vertexPositionZ = ((float) input.z) / factor;
    return float3(vertexPositionX, vertexPositionY, vertexPositionZ);
}

// Iterating over every vertex in the simulation mesh
[numthreads(1, 1, 1)]
void CS(uint3 dispatchThreadID : SV_DispatchThreadID)
{
    int springIterations = 1;
    int neighbourCount = 8;
    uint vertexID = dispatchThreadID.x;
    
    // For the number of iterations
    for (int iter = 0; iter < springIterations; ++iter)
    {
        // For each neighbour of the base vertex
        for (int ni = 0; ni < neighbourCount; ++ni)
        {
            // Get the neighbours vertex ID & rest length
            uint neighbour = simMeshVertexAdjacencyBuffer[vertexID].index[ni];
            float neighbourLength = restConstraintBuffer[vertexID].length[ni];
            
            // Get the current position of the base vertex
            float3 vertexPosition = UnQuantize(simMeshTransformedVertexBuffer[vertexID], QUANTIZE);
        
            // If the neighbour's vertexID is not the same as the base vertexID (unused Neighbours slots are filled with the base vertexID to mark them as void entries)
            if (neighbour != vertexID)
            {
                // Get the current position of the neighbour vertex
                float3 neighbourVertexPosition = UnQuantize(simMeshTransformedVertexBuffer[neighbour], QUANTIZE);

                // Calculate the displacement between the base vertex and neighbour vertex
                float3 neighbourSkinnedDisplacement = vertexPosition - neighbourVertexPosition;
                
                // Find the length of this vector
                float neighbourSkinnedLength = length(neighbourSkinnedDisplacement);
                
                // Calculate scale and correction vector
                float3 correctionVector = (neighbourSkinnedDisplacement * (1.0 - neighbourLength / neighbourSkinnedLength)) * 0.5;
                
                // convert the result to integers ( + & - )
                int quantizedX = (int) (correctionVector.x * QUANTIZE);
                int quantizedY = (int) (correctionVector.y * QUANTIZE);
                int quantizedZ = (int) (correctionVector.z * QUANTIZE);
                int invQuantizedX = (int) (-correctionVector.x * QUANTIZE);
                int invQuantizedY = (int) (-correctionVector.y * QUANTIZE);
                int invQuantizedZ = (int) (-correctionVector.z * QUANTIZE);
 
                // Offset the base and neighbour vertex by half of the correction vector in opposing directions (spring satisfaction)
                InterlockedAdd(simMeshTransformedVertexBuffer[vertexID].x, invQuantizedX);
                InterlockedAdd(simMeshTransformedVertexBuffer[vertexID].y, invQuantizedY);
                InterlockedAdd(simMeshTransformedVertexBuffer[vertexID].z, invQuantizedZ);
                InterlockedAdd(simMeshTransformedVertexBuffer[neighbour].x, quantizedX);
                InterlockedAdd(simMeshTransformedVertexBuffer[neighbour].y, quantizedY);
                InterlockedAdd(simMeshTransformedVertexBuffer[neighbour].z, quantizedZ);
            }
        }
    }
}

As you can see, this is a relatively standard mass spring setup. I know a lot of mass spring solvers don't offset the base and neighbour vertices simultaneously, but I don't store any doubled-up (palindrome) constraints, as that is less efficient, so I apply the correction to both ends of each constraint.

This code seems to be working reasonably well; the main issue is that it only seems to run correctly with 1 thread... I assume there is something I am not doing correctly to ensure that this works properly in parallel. I hoped that the InterlockedAdd function would ensure that the code synchronized properly, but that does not seem to be the case :(

The mass spring compute shader is dispatched like this:

const UINT threadGroupSizeX = 1;
const UINT threadGroupSizeY = 1;
const UINT threadGroupSizeZ = 1;
cmdList->Dispatch((ri->SimMeshVertexCount + threadGroupSizeX - 1) / threadGroupSizeX, threadGroupSizeY, threadGroupSizeZ);

Any help or pointers would be greatly appreciated! I'm a bit lost as to what I should try next.

Reference I used for the mass spring setup: https://viscomp.alexandra.dk/index2fa7.html?p=147

Full implementation: https://github.com/JChittockDev/Research/blob/main/OpenResearchEngine/Render/Manager/Render.cpp


Gregm8 said:
This code seems to be working reasonably well; the main issue is that it only seems to run correctly with 1 thread... I assume there is something I am not doing correctly to ensure that this works properly in parallel. I hoped that the InterlockedAdd function would ensure that the code synchronized properly, but that does not seem to be the case :(

I think you overlook that your atomics only avoid corrupting the individual components of a 3D point; multiple threads (and multiple workgroups) will still concurrently change the x, y, z components without synchronization, which can cause jitter even if you only use one thread per workgroup. To fix this in the way you intend, you would need something like a spinlock, so only one thread at a time can modify a single 3D point. (Or pack all 3 into a 64 bit value, since most GPUs support 64 bit atomics.)

However, even if you did this, you still have the problem that you not only write but also read from the same buffer as far as I can see, and the spinlock for the writes would not be enough to prevent random reads from modified or initial states.

So that's all very bad, and you need to do it differently altogether.
The first issue is actually the use of quantized integers. Simulation on quantized data can work, of course, but it will be lower quality, and you pay a very high cost for the atomics and the quantization. Neither is necessary or beneficial.

Your current approach is a kind of scattering approach. In general, but especially with parallel programming, we can assume a gathering approach will usually be faster.
You also usually need two buffers. One for the initial state, a second to store the new and modified state. That's how we avoid a need for atomics. (Atomics are fast only for LDS (or ‘shared’) memory, but slow for VRAM)

Some pseudo code for the cloth constraints:

vec3 buffer0[NUM_CLOTH_VERTICES]; // current positions; in VRAM
vec3 buffer1[NUM_CLOTH_VERTICES]; // future positions we are about to calculate
ConstraintsInfo infos[NUM_CLOTH_VERTICES];
parallel_for(int i=0; i<NUM_CLOTH_VERTICES; i++)
{
	vec3 correction(0);
	for (int a=0; a<infos[i].adjacencyCount; a++) // we gather all neighbors of vertex i
	{
		correction += CalcCorrection(infos[i].restDistance[a], buffer0[i], buffer0[infos[i].adjacentIndex[a]]); // adding the displacement the adjacency constraint causes to the current vertex i
	}
	correction *= 0.3f; // some damping to avoid solver explosion, may need to be smaller
	
	buffer1[i] = buffer0[i] + correction; // notice only this thread will write to vertex i, because only this thread processes it; all other memory accesses only read from buffer0, which we do not modify
}

After this you can just swap buffers 0 and 1 for the next frame, or for the next iteration of an iterative solver.

You see this is much simpler and avoids all potential write hazards or performance issues. There is no need for sync at all.
But we need to be careful to ensure vertex i to j calculates the same correction as vertex j to i.
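For reference, here is a minimal HLSL sketch of the gather approach described above. It reuses the Neighbours / RestConstraint layout from the original shader, but the float3 position buffers and the CalcCorrection helper are assumptions for illustration, not code from this thread. Note that a gather pass needs each edge stored from both endpoints (i lists j, and j lists i), unlike the deduplicated adjacency in the original post.

struct Neighbours     { uint index[8];   }; // same layout as the original shader
struct RestConstraint { float length[8]; };

StructuredBuffer<Neighbours>     simMeshVertexAdjacencyBuffer : register(t0);
StructuredBuffer<RestConstraint> restConstraintBuffer         : register(t1);
StructuredBuffer<float3>         positionsIn                  : register(t2); // current positions, read only
RWStructuredBuffer<float3>       positionsOut                 : register(u0); // next positions, each written by exactly one thread

float3 CalcCorrection(float restLength, float3 p, float3 q)
{
    float3 d = p - q;
    float len = length(d);
    // move p half way towards its rest distance from q; guard against zero-length edges
    return (len > 1e-6) ? -d * (1.0 - restLength / len) * 0.5 : float3(0.0, 0.0, 0.0);
}

[numthreads(64, 1, 1)]
void CS(uint3 dispatchThreadID : SV_DispatchThreadID)
{
    // one thread per vertex (a bounds check against the vertex count may be needed if the dispatch rounds up)
    uint vertexID = dispatchThreadID.x;

    float3 p = positionsIn[vertexID];
    float3 correction = float3(0.0, 0.0, 0.0);

    // gather: read all neighbours from the read-only buffer and accumulate the corrections locally
    for (int ni = 0; ni < 8; ++ni)
    {
        uint neighbour = simMeshVertexAdjacencyBuffer[vertexID].index[ni];
        if (neighbour != vertexID) // padded slots store the vertex's own ID
        {
            correction += CalcCorrection(restConstraintBuffer[vertexID].length[ni], p, positionsIn[neighbour]);
        }
    }

    correction *= 0.3; // damping to avoid solver explosion, as in the pseudocode above

    // the only write: this thread owns vertexID, so no atomics and no locks are needed
    positionsOut[vertexID] = p + correction;
}

It can be dispatched exactly like the pre-solve pass (one thread per vertex, 64 threads per group), with positionsIn and positionsOut swapped on the host between frames or solver iterations.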

It's interesting to compare this order independent and parallelizable approach with an alternative we could do if we were single threaded.
The alternative does not need the second buffer. We could just modify vertices as we solve one constraint after another.
Usually, the alternative converges faster, so we need fewer iterations to get some desired minimal error, e.g. just 8 instead of 10 for the parallel version.
But that's the price we need to pay with parallel programming. Parallel algorithms almost always do more work than a traditional serial approach.
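For contrast, a minimal sketch of that serial, in-place alternative, mirroring the pseudocode style above (the edge arrays, positions buffer, and CalcCorrection helper are hypothetical names, not code from this thread):

// Single-threaded alternative: constraints are solved in place, one after another,
// so each solved constraint immediately feeds into the next one. It converges in fewer
// iterations, but the in-place writes make it unsafe to parallelise as-is.
for (uint c = 0; c < numEdges; ++c)
{
    uint i = edgeA[c];
    uint j = edgeB[c];
    float3 correction = CalcCorrection(restLength[c], positions[i], positions[j]);
    positions[i] += correction; // moved immediately...
    positions[j] -= correction; // ...so later constraints already see the updated positions
}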

What ^ @joej said. But I'll add a personal axiom: if it's atomic and it's not an index or a counter … you probably need to think again. It applies to all atomics, but compared (min/max) atomics (like an atomic ray-hit distance) are the more obvious example of how atomics become hidden problems when they are bridged between separate points in time: you were the min then, but that's not assured now after that atomic min and other things have happened, and it's not even assured if you check again. What you really want is a mutex/critical-section enclosing those whole sections like we have in CPU land, and we don't have that in GPU land (CUDA has some weirdness that IIRC sorta works, but it's totally a hack that could break any day).
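To make the min/max example concrete, here is a hypothetical HLSL sketch (all names invented) of an atomic ray-hit distance being bridged between two points in time:

RWStructuredBuffer<uint> closestHit : register(u0); // per-pixel quantized hit distance, cleared to 0xFFFFFFFF
RWStructuredBuffer<uint> hitPayload : register(u1); // per-pixel data that should belong to the closest hit

void RecordHit(uint pixel, uint quantizedDistance, uint payload)
{
    uint previous;
    InterlockedMin(closestHit[pixel], quantizedDistance, previous);
    if (quantizedDistance < previous)
    {
        // We were the minimum at the moment of the InterlockedMin, but another thread may have
        // written an even smaller distance since then, so this payload write can silently be stale.
        hitPayload[pixel] = payload;
    }
}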

When you genuinely encounter this problem and R/W ping-pong isn't an option, you can follow the axiom that an index is a safe atomic, and use it to build a sequence of messages that another compute pass will resolve, with another source of default min/max data acting as a singular atomic chokepoint for the transfer from event to stored result. In practice, if the problem is small (smaller than a threadgroup), thread X == 0 usually just concludes this with some memory barriers, no different than your standard prefix-sum; if it's big … ping-pong or multipass it is.
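And a sketch of the "an index is a safe atomic" message pattern (again, all names are hypothetical): producers only reserve slots with a counter, and a later pass resolves the messages into the stored result:

// Pass 1: producers reserve an index atomically and write only their own slot.
struct Message
{
    uint targetVertex;
    float3 delta;
};

RWStructuredBuffer<uint>    messageCount : register(u0); // single counter, cleared to 0 before the pass
RWStructuredBuffer<Message> messages     : register(u1); // capacity assumed large enough for the worst case

void PushMessage(uint target, float3 delta)
{
    uint slot;
    InterlockedAdd(messageCount[0], 1, slot); // the only atomic: reserving a unique slot
    messages[slot].targetVertex = target;
    messages[slot].delta        = delta;
}

// Pass 2 (a separate dispatch): a resolver walks the message list and folds each message into the
// stored result, so every final value is written deterministically, without atomic read-modify-write.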

