#include "data\\shaders\\input_formats.h"
#include "data\\shaders\\pbd\\pbd_common.h"

// Read-only inputs (SRVs): particle state, distance constraints, and the
// per-particle list of constraint indices.
StructuredBuffer<Particle>       in_particles     : register(t0);
StructuredBuffer<DistConstraint> constraints      : register(t1);
StructuredBuffer<uint>           constraints_inds : register(t2);
// Output (UAV): particle state after this constraint-projection pass.
RWStructuredBuffer<Particle>     out_particles    : register(u0);

// Per-frame constants bound at slot b0.
// NOTE(review): member order/types define the constant-buffer layout
// (double at offset 0, then two uints); the host-side struct must match exactly.
cbuffer PerFrame : register(b0)
{
  // Simulation time step; used as alpha = compliance / (delta_time^2) in main.
  double delta_time;
  // Number of valid particles; threads with index >= max_particles exit early.
  uint max_particles;
  // NOTE(review): not read by the kernel in this file — presumably the total
  // constraint count; confirm against the host code.
  uint max_constraints;
}

// Distance-constraint projection, one thread per particle (Jacobi style).
// Each thread reads positions only from in_particles. It accumulates the
// corrections from every distance constraint attached to its particle and
// writes only out_particles[index], so no two threads write the same element.
[numthreads(8,1,1)]
void main(uint3 dispatch_thread_id  : SV_DispatchThreadID)
{
  // Stride of constraints_inds per particle (slots reserved per particle).
  // NOTE(review): a particle with num_constraints > 10 would read into the next
  // particle's slots — confirm the host clamps num_constraints to this value.
  const uint kConstraintsPerParticle = 10;
  // Constraint compliance; alpha = compliance / dt^2 as in XPBD.
  const double kCompliance = 0.0001;
  // Below this length the gradient diff / len is undefined (coincident
  // particles); such constraints are skipped instead of producing NaN/Inf.
  const double kMinLength = 1e-12;

  uint index = dispatch_thread_id.x;
  if (index >= max_particles)
    return;

  double alpha = kCompliance / (delta_time * delta_time);

  Particle particle = in_particles[index];
  uint num_constraints = particle.num_constraints;
  Double3 p0 = PackDouble(particle.p_x, particle.p_y, particle.p_z);

  // uint loop counter: num_constraints is unsigned, avoiding a signed/unsigned compare.
  for (uint i = 0; i < num_constraints; ++i)
  {
    uint con_index = constraints_inds[index * kConstraintsPerParticle + i];
    DistConstraint con = constraints[con_index];

    // The constraint links (p0, p1); pick whichever endpoint is not this particle.
    uint other_index = (index == con.p1) ? con.p0 : con.p1;
    Particle other = in_particles[other_index];

    Double3 p1 = PackDouble(other.p_x, other.p_y, other.p_z);
    Double3 diff = Minus(p0, p1);
    double len = Length(diff);
    if (len <= kMinLength)
      continue;

    // C = |p0 - p1| - rest_length; lambda uses both particles' w plus compliance.
    double C = len - con.l0;
    double lambda = -C / (particle.w + other.w + alpha);

    // Unit gradient of C with respect to p0.
    double grad_x = diff.x / len;
    double grad_y = diff.y / len;
    double grad_z = diff.z / len;

    // NOTE(review): textbook (X)PBD scales this correction by this particle's
    // inverse mass (particle.w), so pinned particles (w == 0) do not move. The
    // original used a literal 1.0 factor here; confirm that is intentional.
    p0.x += grad_x * lambda;
    p0.y += grad_y * lambda;
    p0.z += grad_z * lambda;
  }

  particle.p_x = p0.x;
  particle.p_y = p0.y;
  particle.p_z = p0.z;

  out_particles[index] = particle;
}