/* Copyright (C) 2013 E.J. Brambley

   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU General Public License as
   published by the Free Software Foundation; either version 3 of the
   License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
   General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, see <http://www.gnu.org/licenses>.

   Additional permission under GNU GPL version 3 section 7

   If you modify this Program, or any covered work, by linking or
   combining it with D.E. Amos' Algorithm 644
   (http://www.netlib.org/toms/644) (or a modified version of that
   library), containing parts covered by the terms of the ACM Software
   Copyright and License Agreement
   (www.acm.org/publications/policies/softwarecrnotice), the licensors
   of this Program grant you additional permission to convey the
   resulting work.  Corresponding Source for a non-source form of such
   a combination shall include the source code for the parts of
   Algorithm 644 used as well as that of the covered work.

   If you modify this Program, or any covered work, by linking or
   combining it with LAPACK (http://www.netlib.org/lapack) (or a
   modified version of that library), containing parts covered by the
   terms of the LAPACK modified BSD license
   (http://www.netlib.org/lapack/LICENSE.txt), the licensors of this
   Program grant you additional permission to convey the resulting
   work.
*/

/* 
   This code forms the supplementary material of the publication
   Brambley & Gabard (2014), Journal of Sound and Vibration
   Please acknowledge use of this code by citing that publication.
*/

#include <complex.h>
#include <math.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "solve_for_surface_waves.h"


/**********************************/
/*                                */
/* Solve for roots of polynomials */
/*                                */
/**********************************/

static inline complex double ccbrt(complex double x)
{
  return cpow(x, 1.0/3.0);
}


static void solve_quartic(complex double* const w)
{
  complex double p, q, r, u;

  w[0] /= w[4];
  w[1] /= w[4];
  w[2] /= w[4];
  w[3] /= w[4];
  
  /* Calculate coefficients of x^4 + p x^2 + q^x + r = 0 */
  p = w[2] - 3.0/8.0*w[3]*w[3];
  q = w[1] - 0.5*w[2]*w[3] + w[3]*w[3]*w[3]/8.0;
  r = w[0] - 0.25*w[1]*w[3] + w[2]*w[3]*w[3]/16.0 - 3.0/256.0 * w[3]*w[3]*w[3]*w[3];

  /* Solve the cubic */
  u = -4.0/3.0*r - p*p/9.0;
  r = (3.0*q*q - 8.0*r*p)/6.0 + p*p*p/27.0;
  r = ccbrt(r + csqrt(u*u*u + r*r));
  u = r - u/r + p/3.0;

  /* Calcualate quadratic coefficients */
  p = csqrt(u-p);
  
  /* Solve quadratic */
  r = csqrt(p*p - 2.0*(u - q/p));
  w[0] = -0.5*(p + r) - 0.25*w[3];
  w[1] = -0.5*(p - r) - 0.25*w[3];
  r = csqrt(p*p - 2.0*(u + q/p));
  w[2] = 0.5*(p + r) - 0.25*w[3];
  w[3] = 0.5*(p - r) - 0.25*w[3];
}


/* ZGEEV from LAPACK: eigenvalues (and optionally left/right
   eigenvectors) of a general complex N-by-N matrix.  Declared here by
   hand using the Fortran calling convention: every argument is passed
   by pointer and the symbol carries a trailing underscore.
   NOTE(review): the hidden string-length arguments that Fortran
   compilers use for CHARACTER dummies (JOBVL, JOBVR) are not passed.
   This is common practice for single-character flags but is
   compiler-dependent; confirm against the LAPACK build in use. */
void zgeev_(char* JOBVL, char* JOBVR,
	    int* N,
	    complex double* A, int* LDA,
	    complex double* W,
	    complex double* VL, int* LDVL,
	    complex double* VR, int* LDVR,
	    complex double* WORK, int* LWORK,
	    double* RWORK,
	    int* INFO);


/* Solve the polynomial a_0 + a_1 x + ... + a_n x^n = 0, where
   a_i = w[i], as the eigenvalues of its companion matrix using LAPACK's
   ZGEEV.

   Trailing zero coefficients are dropped first, so the effective degree
   may be lower than n.  On success the first (effective degree)
   elements of w hold the roots; later elements are left unchanged.

   Returns 0 on success.  Otherwise it returns one of:
     - 2 if no nonzero coefficient a_1..a_n exists (or n <= 0);
     - (effective degree) + 1 on allocation failure;
     - the nonzero INFO value reported by ZGEEV.  */
static int solve_polynomial(int n, complex double* const w)
{
  char jobvl[] = "N";
  char jobvr[] = "N";
  /* Eigenvectors are not requested (JOBVL = JOBVR = 'N'), so ZGEEV
     never references VL/VR; dummies with leading dimension 1 suffice. */
  complex double vl_unused = 0.0;
  complex double vr_unused = 0.0;
  int ldv = 1;
  complex double work_query = 0.0;
  complex double* work = NULL;
  complex double* a = NULL;
  double* rwork = NULL;
  int lwork = -1;
  int info = 0;
  int i;

  /* Drop vanishing leading coefficients */
  while (n > 0 && w[n] == 0.0)
    n--;
  if (n <= 0)
    return 2;

  /* Zero-filled n-by-n matrix and real workspace of size 2n */
  a = calloc((size_t)n * (size_t)n, sizeof *a);
  rwork = malloc(sizeof *rwork * 2 * (size_t)n);
  if (!a || !rwork)
    {
      info = n+1;
      goto cleanup;
    }

  /* Companion matrix, column-major (a[row + col*n]): ones on the
     superdiagonal, last row holds -a_i/a_n. */
  for (i = 0; i < n-1; i++)
    a[i + (i+1)*n] = 1.0;
  for (i = 0; i < n; i++)
    a[n-1 + i*n] = -w[i]/w[n];

  /* Workspace query (LWORK = -1): optimal size returned in work_query */
  zgeev_(jobvl, jobvr, &n, a, &n, w, &vl_unused, &ldv, &vr_unused, &ldv,
         &work_query, &lwork, rwork, &info);
  if (info)
    goto cleanup;

  lwork = (int) creal(work_query);
  work = malloc(sizeof *work * (size_t)lwork);
  if (!work)
    {
      info = n+1;
      goto cleanup;
    }

  /* Eigenvalues of the companion matrix are written into w[0..n-1] */
  zgeev_(jobvl, jobvr, &n, a, &n, w, &vl_unused, &ldv, &vr_unused, &ldv,
         work, &lwork, rwork, &info);

 cleanup:
  free(work);
  free(rwork);
  free(a);
  return info;
}


/***************************/
/*                         */
/* Solve for surface waves */
/*                         */
/***************************/

/* Surface-wave dispersion relation of Brambley & Peake (2006, WM) for a
   locally-reacting impedance Z in a uniform flow of Mach number M.

   The relation is the quartic in k
     (omega - M k)^4 / Z^2 - omega^2 ((omega - M k)^2 - k^2) = 0.
   Its coefficients are written to surface_modes[0..4] and then
   overwritten by its roots.  For |M| < 1e-14 it degenerates to
   k^2 = omega^2 (1 - 1/Z^2), and only the two roots +-k are produced.

   surface_modes must have room for at least 5 values.  Returns the
   number of roots stored at its start (2 or 4).  */
int solve_for_surface_waves(const complex double omega,
			    complex double* const surface_modes,
			    const complex double Z,
			    const double M)
{
  const complex double Z2 = Z*Z;
  complex double* const modes = surface_modes;

  if (!(fabs(M) < 1e-14))
    {
      /* Coefficient of k^i goes into modes[i] */
      modes[0] = -omega*omega*omega*omega + omega*omega*omega*omega/Z2;
      modes[1] = 2.*M*omega*omega*omega - 4.*M*omega*omega*omega/Z2;
      modes[2] = (1.-M*M)*omega*omega + 6.*M*M*omega*omega/Z2;
      modes[3] = -4.*M*M*M*omega/Z2;
      modes[4] = M*M*M*M/Z2;

      solve_quartic(modes);
      return 4;
    }

  /* No-flow limit: a single pair of opposite roots */
  modes[0] = csqrt(omega*omega*(1.-1/Z2));
  modes[1] = -modes[0];
  return 2;
}


/* Solve the surface wave dispersion relation of Brambley (2013, JSV) for
   a locally-reacting impedance Z in a flow of Mach number M, using the
   modified boundary condition parameterised by d_mass, d_mom, d_ke and
   d_1.

   For |M| < 1e-14 the relation reduces to
     k^2 = omega^2 (1 - 1/(Z - i omega d_mass)^2)
   and two roots are returned.  Otherwise a polynomial of degree at most
   6 in k is assembled in surface_modes[0..6] and solved in place.  The
   degree is lower when the leading coefficients vanish, e.g. when
   d_ke == 0 and d_1 == 0.

   surface_modes must have room for at least 7 values.  Returns the
   number of roots stored at its start, or 0 if the polynomial could not
   be solved (or has no nonconstant term).  */
int solve_for_surface_waves_bc(const complex double omega,
			       complex double* const surface_modes,
			       const complex double Z,
			       const double M,
			       const double d_mass,
			       const double d_mom,
			       const double d_ke,
			       const double d_1)
{
  if (fabs(M) < 1e-14)
    {
      surface_modes[0] = csqrt(omega*omega*(1.-1/((Z - I*omega*d_mass)*(Z - I*omega*d_mass))));
      surface_modes[1] = -surface_modes[0];
      return 2;
    }

  /* Write equation to solve as: (a0 + a1*k + a2*k^2)^2*(b0 + b1*k + b2*k^2) - (c0 + c1*k + c2*k^2 + c3*k^3)^2 = 0 */
  const complex double a0 = I*omega*Z + omega*omega*d_mass;
  const complex double a1 = -2.*omega*M*d_mom;
  const complex double a2 = M*M*d_ke;
  const complex double b0 = -omega*omega;
  const complex double b1 = 2.*omega*M;
  const complex double b2 = 1.-M*M;
  const complex double c0 = omega*omega;
  const complex double c1 = -2.*omega*M;
  const complex double c2 = M*M;
  const complex double c3 = I*Z*d_1*M;

  /* Rewrite this as: (d0 + d1*k + d2*k^2 + d3*k^3 + d4*k^4)*(b0 + b1*k + b2*k^2) - (e0 + e1*k + e2*k^2 + e3*k^3 + e4*k^4 + e5*k^5 + e6*k^6) = 0 */
  const complex double d0 =    a0*a0;
  const complex double d1 = 2.*a0*a1;
  const complex double d2 = 2.*a0*a2 +    a1*a1;
  const complex double d3 = 2.*a1*a2;
  const complex double d4 =    a2*a2;
  const complex double e0 =    c0*c0;
  const complex double e1 = 2.*c0*c1;
  const complex double e2 = 2.*c0*c2 +    c1*c1;
  const complex double e3 = 2.*c0*c3 + 2.*c1*c2;
  const complex double e4 = 2.*c1*c3 +    c2*c2;
  const complex double e5 = 2.*c2*c3;
  const complex double e6 =    c3*c3;

  /* Work out our polynomial in k to solve */
  surface_modes[0] =                 d0*b0 - e0;
  surface_modes[1] =         d0*b1 + d1*b0 - e1;
  surface_modes[2] = d0*b2 + d1*b1 + d2*b0 - e2;
  surface_modes[3] = d1*b2 + d2*b1 + d3*b0 - e3;
  surface_modes[4] = d2*b2 + d3*b1 + d4*b0 - e4;
  surface_modes[5] = d3*b2 + d4*b1         - e5;
  surface_modes[6] = d4*b2                 - e6;

  /* Determine the true degree here.  Otherwise, when leading
     coefficients vanish, solve_polynomial finds fewer roots and the
     remaining slots would still hold polynomial coefficients, which
     must not be reported as modes. */
  int degree = 6;
  while (degree > 0 && surface_modes[degree] == 0.0)
    degree--;

  if (degree == 0 || solve_polynomial(degree, surface_modes))
    return 0;

  return degree;
}
