// zhesv.cc
// 複素エルミート行列の線形方程式を解く

#include <stdio.h>

#include <complex>
#include <cstddef>
#include <vector>
// std::complex<double> stores (real, imag) as two consecutive doubles, the
// same layout as Fortran COMPLEX*16, so arrays of it can be handed to LAPACK.
typedef std::complex<double> Complex;

extern "C" {
  // LAPACK ZHESV (Fortran, all scalars passed by reference): solves A X = B
  // for a complex Hermitian N x N matrix A.
  //   A, B    : column-major arrays passed as a pointer to their first element
  //             (Fortran A(LDA,*), B(LDB,*)). A is overwritten by the
  //             factorization, B by the solution X.
  //   IPIV    : pivot indices, length N.
  //   WORK    : workspace of LWORK elements (LWORK >= 1 is required).
  //   INFO    : 0 on success; -i if argument i was illegal; i > 0 if D(i,i)
  //             is exactly zero, in which case no solution is computed.
  //   UPLOlen : hidden length of the CHARACTER argument UPLO that the Fortran
  //             compiler appends; gfortran >= 8 passes it as size_t.
  void zhesv_ ( const char& UPLO,
                const int& N, const int& NRHS,
                Complex* A, const int& LDA, int* IPIV,
                Complex* B, const int& LDB,
                Complex* WORK, const int& LWORK,
                int& INFO, std::size_t UPLOlen );
}

// Convenience wrapper: solves A x = b for a complex Hermitian matrix A.
// Input : A[N][N] (row-major C array), b[N]      Output: x[N]
// Only the lower triangle A[i][j] with j <= i is referenced; the strict upper
// triangle is ignored (it is implied by Hermitian symmetry). A and b are not
// modified. x and b may point to the same array.
//
// Returns LAPACK's INFO:
//   0   success, x holds the solution;
//   < 0 argument -INFO passed to zhesv_ had an illegal value;
//   > 0 D(INFO,INFO) is exactly zero (A is singular); x is NOT a solution.
template <int N> int zhesv( Complex A[N][N], Complex x[N], const Complex b[N] )
{
  // Per-call scratch buffers (rather than function-level statics) keep this
  // function reentrant and safe to call concurrently from several threads.
  std::vector<int> ipiv( N );
  const int lwork = 4*N;  // ZHESV needs LWORK >= 1; larger allows bigger blocks
  std::vector<Complex> work( lwork );

  // LAPACK expects column-major storage: element (r,c) lives at r + c*N.
  // Copying A[r][c] there converts from C's row-major layout, so Fortran's
  // 'L' (lower) triangle corresponds to A[i][j<=i] on the C side.
  std::vector<Complex> a( N*N );
  for( int c=0; c<N; c++ ){
    for( int r=0; r<N; r++ ){
      a[r + c*N] = A[r][c];
    }
  }

  // zhesv_ overwrites its right-hand side with the solution, so solve in x.
  for( int i=0; i<N; i++ ){
    x[i] = b[i];
  }

  int info = 0;
  zhesv_( 'L', N, 1, &a[0], N, &ipiv[0], x, N, &work[0], lwork, info, 1 );

  return info;
}

// Demo driver: solves A x = b for a small complex Hermitian matrix with
// zhesv(), then prints INFO, the solution x and the residual A x - b.
// Returns 0 on success, 1 if zhesv() reports a nonzero INFO.
int main(void)
{
  const int N = 4;
  int i, j, info;

  Complex A[N][N];
  Complex x[N], b[N];

  // Hermitian test matrix: the real part is symmetric and the imaginary part
  // antisymmetric, so A[j][i] == conj(A[i][j]). This also makes the residual
  // check below, which uses both triangles, valid. Real entries 1+i+j alone
  // would give a rank-2 (singular) matrix, since row i = row 0 + i*(1,...,1).
  // The positive diagonal 4*N*N exceeds every off-diagonal row sum (at most
  // (N-1)*(3N-2)), so A is strictly diagonally dominant, hence nonsingular.
  for( i=0; i<N; i++ ){
    for( j=0; j<N; j++ ){
      if( i == j ){
        A[i][j] = Complex( 4.0*N*N, 0.0 );
      } else {
        A[i][j] = Complex( 1.0+i+j, i-j );
      }
    }
    b[i] = 5.0;
  }

  info = zhesv( A, x, b );

  printf("# info=%d.\n", info );
  if( info != 0 ){
    // INFO < 0: an argument to ZHESV was illegal. INFO > 0: D(INFO,INFO) is
    // exactly zero, so x does not contain a solution. Either way, printing
    // x and a residual would be meaningless.
    printf("# zhesv failed; no solution.\n");
    return 1;
  }

  printf("# Solution.\n");
  for( j=0; j<N; j++ ){
    printf("%+f%+f\n", real(x[j]), imag(x[j]) );
  }

  // Residual A x - b for each row; expected to be close to zero.
  printf("# Error.\n");
  for( i=0; i<N; i++ ){
    Complex sum=0.0;
    for( j=0; j<N; j++ ){
      sum += A[i][j]*x[j];
    }
    sum -= b[i];

    printf("%+f%+f\n", real(sum), imag(sum) );
  }

  return 0;
}