//
// ginac_write_code.cpp
//
// A collection of functions that write C code from
// ginac expressions.
//
// Warren Weckesser
// Department of Mathematics
// Colgate University
// Hamilton, NY 13346
//

#include <fstream>
#include <iostream>
#include <string>
#include <ginac/ginac.h>

using namespace std;
using namespace GiNaC;


//
// DeclareDouble emits a single C declaration line of the form
//     double a, b, c;
// listing every element of names.  Nothing is written when names is empty.
//
void DeclareDouble(ofstream *fout, lst names)
    {
    const size_t count = names.nops();
    if (count == 0)
        return;

    // The first element is preceded by the type; the rest by a comma.
    const char *sep = "    double ";
    for (size_t k = 0; k < count; ++k)
        {
        *fout << sep;
        *fout << names[k];
        sep = ", ";
        }
    *fout << ";" << endl;
    }

//
// DeclareAuxDouble emits a single C declaration line
//     double lhs0, lhs1, ...;
// naming the left-hand side of every entry of aux (entries are accessed
// with .lhs(), so they are expected to be GiNaC relationals such as
// y == expr).  Nothing is written when aux is empty.
//
void DeclareAuxDouble(ofstream *fout, lst aux)
    {
    const size_t total = aux.nops();
    if (total == 0)
        return;

    ofstream &out = *fout;
    out << "    double ";
    out << aux[0].lhs();
    for (size_t k = 1; k < total; ++k)
        {
        out << ", ";
        out << aux[k].lhs();
        }
    out << ";" << endl;
    }


void GetFromGSLVector(ofstream *fout, lst names, char *vector)
    {
    int n;

    n = names.nops();
    for (int i = 0; i < n; ++i)
        {
        *fout << "    ";
        fout->width(9);
        *fout << names[i];
        fout->width(0);
        *fout << " = gsl_vector_get(" << vector << "," << i << ");" << endl;
        }
    }

void GetFromCArray(ofstream *fout, lst names, char *vector)
    {
    int n;

    n = names.nops();
    for (int i = 0; i < n; ++i)
        {
        *fout << "    ";
        fout->width(9);
        *fout << names[i];
        fout->width(0);
        *fout << " = ((double *)" << vector << ")[" << i << "];" << endl;
        }
    }


//
// In the following functions, mode is a string that can be "c" or "gsl".
// It determines how the generated code receives its vectors and matrices:
// "gsl" produces code that uses the GNU Scientific Library gsl_vector and
// gsl_matrix types, while "c" (or any value other than "gsl") produces
// code that uses plain C arrays.
//

//
// WriteFunction writes the file <name>.c, which defines the C function
//     double <name>(x_, p_)
// The generated function copies vars from x_ and params from p_ into
// local doubles, evaluates the aux assignments in list order, and returns
// the value of expr.  mode "gsl" makes x_ and p_ gsl_vector*; any other
// mode makes them const double*.
//
// Returns 0 on success, -1 if <name>.c could not be opened or written.
//
int WriteFunction(string mode, string name, ex expr, lst vars, lst params, lst aux)
    {
    string filename = name + ".c";
    ofstream fout(filename.c_str());
    if (!fout)
        return -1;

    // Print GiNaC expressions as C source; left-justify padded names.
    fout << csrc << left;

    fout << "/*" << endl;
    fout << " *  C file for " << name << endl;
    fout << " *" << endl;
    fout << " */" << endl;
    fout << endl;
    fout << "#include <math.h>" << endl;
    if (mode == "gsl")
        {
        fout << "#include <gsl/gsl_errno.h>" << endl;
        fout << "#include <gsl/gsl_matrix.h>" << endl;
        }
    fout << endl;

    fout << "double " << name;
    if (mode == "gsl")
        fout << "(gsl_vector* x_, gsl_vector* p_)\n";
    else
        fout << "(const double* x_, const double* p_)\n";
    fout << "    {\n";
    DeclareDouble(&fout,vars);
    DeclareDouble(&fout,params);
    DeclareAuxDouble(&fout,aux);
    // Local that holds the result before it is returned.
    fout << "    double " << name << "_;\n";
    fout << endl;
    if (mode == "gsl")
        {
        GetFromGSLVector(&fout,vars,"x_");
        GetFromGSLVector(&fout,params,"p_");
        }
    else
        {
        GetFromCArray(&fout,vars,"x_");
        GetFromCArray(&fout,params,"p_");
        }
    fout << endl;
    // Aux quantities are assigned in list order, so a later entry may use
    // the value of an earlier one.
    for (size_t i = 0; i < aux.nops(); ++i)
        fout << "    " << aux[i].lhs() << " = " << aux[i].rhs() << ";\n";
    fout << "    " << name << "_ = " << expr << ";\n";
    fout << "    return(" << name << "_);\n";
    fout << "    }\n";

    // close() flushes; fail() then reports any write or close error.
    fout.close();
    return fout.fail() ? -1 : 0;
    }

//
// WriteMatrixFunction writes the file <name>.c, which defines the C function
//     int <name>(x_, p_, J_)
// The generated function copies vars from x_ and params from p_ into
// local doubles, evaluates the aux assignments in list order, and stores
// every entry J(i,j) into J_.  With mode "gsl" the arguments are
// gsl_vector*/gsl_matrix* and the generated function returns GSL_SUCCESS;
// otherwise they are double* and a rows x cols C array, and it returns 0.
//
// WriteMatrixFunction itself returns 0 on success, -1 if <name>.c could
// not be opened or written.
//
int WriteMatrixFunction(string mode, string name, matrix J, lst vars, lst params, lst aux)
    {
    string filename = name + ".c";
    ofstream fout(filename.c_str());
    if (!fout)
        return -1;

    // Print GiNaC expressions as C source; left-justify padded names.
    fout << csrc << left;

    fout << "/*" << endl;
    fout << " *  C file for " << name << endl;
    fout << " *" << endl;
    fout << " */" << endl;
    fout << endl;
    fout << "#include <math.h>" << endl;
    if (mode == "gsl")
        {
        fout << "#include <gsl/gsl_errno.h>" << endl;
        fout << "#include <gsl/gsl_matrix.h>" << endl;
        }
    fout << endl;

    fout << "int " << name;
    if (mode == "gsl")
        fout << "(gsl_vector* x_, gsl_vector* p_, gsl_matrix* J_)\n";
    else
        fout << "(double* x_, double* p_, double J_[" << J.rows() << "][" << J.cols() << "])\n";
    fout << "    {\n";
    DeclareDouble(&fout,vars);
    DeclareDouble(&fout,params);
    DeclareAuxDouble(&fout,aux);
    fout << endl;
    if (mode == "gsl")
        {
        GetFromGSLVector(&fout,vars,"x_");
        GetFromGSLVector(&fout,params,"p_");
        }
    else
        {
        GetFromCArray(&fout,vars,"x_");
        GetFromCArray(&fout,params,"p_");
        }
    fout << endl;
    // Aux quantities are assigned in list order, so a later entry may use
    // the value of an earlier one.
    for (size_t i = 0; i < aux.nops(); ++i)
        fout << "    " << aux[i].lhs() << " = " << aux[i].rhs() << ";\n";
    for (unsigned i = 0; i < J.rows(); ++i)
        for (unsigned j = 0; j < J.cols(); ++j)
            {
            if (mode == "gsl")
                fout << "    gsl_matrix_set(J_," << i << "," << j << "," << J(i,j) << ");\n";
            else
                fout << "    J_[" << i << "][" << j << "] = " << J(i,j) << ";\n";
            }

    if (mode == "gsl")
        fout << "    return(GSL_SUCCESS);\n";
    else
        fout << "    return(0);\n";
    fout << "    }\n";

    // close() flushes; fail() then reports any write or close error.
    fout.close();
    return fout.fail() ? -1 : 0;
    }

//
// WriteVectorFunction creates a function to compute the nx1 matrix v.
//

//
// The generated function int <name>(x_, p_, v_) copies vars from x_ and
// params from p_, evaluates the aux assignments in list order, and stores
// v(i,0) into element i of v_ (only column 0 of v is used).  With mode
// "gsl" the arguments are gsl_vector* and it returns GSL_SUCCESS;
// otherwise they are double* and it returns 0.
//
// WriteVectorFunction itself returns 0 on success, -1 if <name>.c could
// not be opened or written.
//
int WriteVectorFunction(string mode, string name, matrix v, lst vars, lst params, lst aux)
    {
    string filename = name + ".c";
    ofstream fout(filename.c_str());
    if (!fout)
        return -1;

    // Print GiNaC expressions as C source; left-justify padded names.
    fout << csrc << left;

    fout << "/*" << endl;
    fout << " *  C file for " << name << endl;
    fout << " *" << endl;
    fout << " */" << endl;
    fout << endl;
    fout << "#include <math.h>" << endl;
    if (mode == "gsl")
        {
        fout << "#include <gsl/gsl_errno.h>" << endl;
        fout << "#include <gsl/gsl_matrix.h>" << endl;
        }
    fout << endl;

    fout << "int " << name;
    if (mode == "gsl")
        fout << "(gsl_vector* x_, gsl_vector* p_, gsl_vector* v_)\n";
    else
        fout << "(double* x_, double* p_, double* v_)\n";
    fout << "    {\n";
    DeclareDouble(&fout,vars);
    DeclareDouble(&fout,params);
    DeclareAuxDouble(&fout,aux);
    fout << endl;
    if (mode == "gsl")
        {
        GetFromGSLVector(&fout,vars,"x_");
        GetFromGSLVector(&fout,params,"p_");
        }
    else
        {
        GetFromCArray(&fout,vars,"x_");
        GetFromCArray(&fout,params,"p_");
        }
    fout << endl;
    // Aux quantities are assigned in list order, so a later entry may use
    // the value of an earlier one.
    for (size_t i = 0; i < aux.nops(); ++i)
        fout << "    " << aux[i].lhs() << " = " << aux[i].rhs() << ";\n";
    fout << endl;
    for (unsigned i = 0; i < v.rows(); ++i)
        {
        if (mode == "gsl")
            fout << "    gsl_vector_set(v_," << i << "," << v(i,0) << ");\n";
        else
            fout << "    v_[" << i << "] = " << v(i,0) << ";\n";
        }
    fout << endl;
    if (mode == "gsl")
        fout << "    return(GSL_SUCCESS);\n";
    else
        fout << "    return(0);\n";
    fout << "    }\n";

    // close() flushes; fail() then reports any write or close error.
    fout.close();
    return fout.fail() ? -1 : 0;
    }


//
// WriteGSLVectorField creates a function to be used with the GSL ODEIV suite.
//

//
// The generated function
//     int <name>(double t, const double x_[], double *vf_, void *par_)
// reads vars from x_ and params from par_ (treated as an array of
// doubles), evaluates the aux assignments in list order, stores v(i,0)
// into vf_[i] (only column 0 of v is used) and returns GSL_SUCCESS.
// x_ is const so the prototype matches the function member of GSL's ODE
// system struct, which takes const double y[].
//
// WriteGSLVectorField itself returns 0 on success, -1 if <name>.c could
// not be opened or written.
//
int WriteGSLVectorField(string name, matrix v, lst vars, lst params, lst aux)
    {
    string filename = name + ".c";
    ofstream fout(filename.c_str());
    if (!fout)
        return -1;

    // Print GiNaC expressions as C source; left-justify padded names.
    fout << csrc << left;

    fout << "/*" << endl;
    fout << " *  GSL Vector Field file for " << name << endl;
    fout << " *" << endl;
    fout << " */" << endl;
    fout << endl;
    fout << "#include <math.h>" << endl;
    fout << "#include <gsl/gsl_errno.h>" << endl;
    fout << "#include <gsl/gsl_matrix.h>" << endl;
    fout << endl;

    fout << "int " << name << "(double t, const double x_[], double *vf_, void *par_)\n";
    fout << "    {\n";
    DeclareDouble(&fout,vars);
    DeclareDouble(&fout,params);
    DeclareAuxDouble(&fout,aux);
    fout << endl;
    GetFromCArray(&fout,vars,"x_");
    GetFromCArray(&fout,params,"par_");
    fout << endl;
    // Aux quantities are assigned in list order, so a later entry may use
    // the value of an earlier one.
    for (size_t i = 0; i < aux.nops(); ++i)
        fout << "    " << aux[i].lhs() << " = " << aux[i].rhs() << ";\n";
    fout << endl;
    for (unsigned i = 0; i < v.rows(); ++i)
        fout << "    vf_[" << i << "] = " << v(i,0) << ";\n";
    fout << "    return(GSL_SUCCESS);\n";
    fout << "    }\n";

    // close() flushes; fail() then reports any write or close error.
    fout.close();
    return fout.fail() ? -1 : 0;
    }

//
// WriteGSLVectorFieldJacobian creates a function to be used with the GSL ODEIV suite.
// (This code assumes that the vector field is autonomous.)
//

//
// The generated function
//     int <name>(double t, const double x_[], double *jac_, double *dfdt_, void *par_)
// reads vars from x_ and params from par_ (treated as an array of
// doubles), evaluates the aux assignments in list order, writes every
// entry J(i,j) into jac_ through a gsl_matrix view, sets every dfdt_[i]
// to 0.0 and returns GSL_SUCCESS.  x_ is const so the prototype matches
// the jacobian member of GSL's ODE system struct, which takes
// const double y[].
//
// WriteGSLVectorFieldJacobian itself returns 0 on success, -1 if <name>.c
// could not be opened or written.
//
int WriteGSLVectorFieldJacobian(string name, matrix J, lst vars, lst params, lst aux)
    {
    string filename = name + ".c";
    ofstream fout(filename.c_str());
    if (!fout)
        return -1;

    // Print GiNaC expressions as C source; left-justify padded names.
    fout << csrc << left;

    fout << "/*" << endl;
    fout << " *  GSL Vector Field Jacobian file for " << name << endl;
    fout << " *" << endl;
    fout << " */" << endl;
    fout << endl;
    fout << "#include <math.h>" << endl;
    fout << "#include <gsl/gsl_errno.h>" << endl;
    fout << "#include <gsl/gsl_matrix.h>" << endl;
    fout << endl;

    fout << "int " << name << "(double t, const double x_[], double *jac_, double *dfdt_, void *par_)\n";
    fout << "    {\n";
    DeclareDouble(&fout,vars);
    DeclareDouble(&fout,params);
    DeclareAuxDouble(&fout,aux);
    fout << endl;
    GetFromCArray(&fout,vars,"x_");
    GetFromCArray(&fout,params,"par_");
    fout << endl;
    // Aux quantities are assigned in list order, so a later entry may use
    // the value of an earlier one.
    for (size_t i = 0; i < aux.nops(); ++i)
        fout << "    " << aux[i].lhs() << " = " << aux[i].rhs() << ";\n";
    fout << endl;
    // View the flat jac_ buffer as a rows x cols gsl_matrix.
    fout << "    gsl_matrix_view jac_mat = gsl_matrix_view_array(jac_," << J.rows() << "," << J.cols() << ");\n";
    fout << "    gsl_matrix *J_ = &jac_mat.matrix;\n";
    fout << endl;
    for (unsigned i = 0; i < J.rows(); ++i)
        for (unsigned j = 0; j < J.cols(); ++j)
            fout << "    gsl_matrix_set(J_," << i << "," << j << "," << J(i,j) << ");\n";
    fout << endl;
    // The vector field is treated as autonomous, so df/dt is zero.
    for (unsigned i = 0; i < J.rows(); ++i)
        fout << "    dfdt_[" << i << "] = 0.0;\n";
    fout << endl;
    fout << "    return(GSL_SUCCESS);\n";
    fout << "    }\n";

    // close() flushes; fail() then reports any write or close error.
    fout.close();
    return fout.fail() ? -1 : 0;
    }

