/*
 * cpldosc_slowvf_solve.c
 *
 *  This function computes the solution of the slow flow of the
 *  coupled oscillator system, and stops when the solution reaches
 *  a fold.
 */


#include <stdio.h>
#include <gsl/gsl_errno.h>
#include <gsl/gsl_matrix.h>
#include <gsl/gsl_odeiv.h>
#include <math.h>

/*
 *  External functions used by cpldosc_slowvf_solve (defined elsewhere).
 *
 *  evt_stepsize: returns a step size for the solver, computed from the system
 *  sys, the current point y, the bounds hmin/hmax, the event function
 *  eventfunc and its gradient grad_eventfunc.  cpldosc_slowvf_solve treats a
 *  negative return value as an error.  NOTE(review): the exact meaning of the
 *  returned step (e.g. "largest step that cannot jump over an event") and of
 *  the optional f argument is not visible here -- confirm against its
 *  definition.
 */
double evt_stepsize(const gsl_odeiv_system *sys, double *y,
                    const double hmin, const double hmax,
                    double (*eventfunc)(const double *y, const double *params),
                    void (*grad_eventfunc)(const double *y, const double *params,
                                                              double *grad),
                    double *f );

/*
 *  cpldosc_foldfunc: the fold (event) function; the solver stops where it
 *  changes sign.  cpldosc_foldfunc_grad: presumably its gradient with respect
 *  to x, written to grad -- TODO confirm.
 *
 *  NOTE(review): these signatures do not match the eventfunc/grad_eventfunc
 *  parameters of evt_stepsize above (those take pointers to const double, and
 *  grad_eventfunc returns void rather than int).  Passing these functions
 *  directly to evt_stepsize therefore converts between incompatible
 *  function pointer types.
 */
double cpldosc_foldfunc(double *x, double *p);
int cpldosc_foldfunc_grad(double *x, double *p, double *grad);
/* cpldosc_fhat: computes q[2] from x and p; printed as q1 q2 in the solver output. */
int cpldosc_fhat(double *x, double *p, double *q);
/* cpldosc_slowvf: the slow vector field, with the gsl_odeiv_system function signature. */
int cpldosc_slowvf(double t, const double x_[], double * vf_, void * params);
/* cpldosc_slowvf_jac: Jacobian of the slow vector field, with the gsl_odeiv_system jacobian signature. */
int cpldosc_slowvf_jac(double t, const double x_[], double *jac_, double *dfdt_, void *params);


/*
 *  Adapters that give cpldosc_foldfunc and cpldosc_foldfunc_grad exactly the
 *  signatures evt_stepsize expects for its eventfunc and grad_eventfunc
 *  arguments.  Passing the functions directly would convert between
 *  incompatible function pointer types, and calling a function through a
 *  pointer of an incompatible type is undefined behavior in C (recent GCC
 *  versions reject the conversion outright).  The casts only drop const so
 *  the arguments can be forwarded.  NOTE(review): this assumes the fold
 *  functions do not write through x or p -- confirm against their definitions.
 */
static double cpldosc_fold_event(const double *x, const double *params)
    {
    return cpldosc_foldfunc((double *) x, (double *) params);
    }

static void cpldosc_fold_event_grad(const double *x, const double *params, double *grad)
    {
    (void) cpldosc_foldfunc_grad((double *) x, (double *) params, grad);
    }


/*
 *  Print one point of the solution to stdout, in the format
 *     t v1 v2 q1 q2 foldfunc
 *  where (q1,q2) is computed by cpldosc_fhat.
 */
static void cpldosc_print_point(double t, double y[2], double p[3])
    {
    double q[2];

    cpldosc_fhat(y,p,q);
    printf("%16.12e %16.12e %16.12e %16.12e %16.12e %16.12e\n", t, y[0], y[1], q[0],q[1],cpldosc_foldfunc(y,p));
    }


/*
 *  cpldosc_refine_fold
 *
 *  Refine the location of a sign change of the fold function with the
 *  secant method.  Each iteration linearly interpolates the fold function
 *  between the last two points and takes one explicit step (usually
 *  backwards in time) with the low-level GSL step function to the predicted
 *  zero.
 *
 *  On entry:
 *     (*t, y)  point just past the sign change; g1 is the fold function there
 *     h        the step that carried the solution from the previous point,
 *              where the fold function was prevg, to (*t, y)
 *  On return, (*t, y) holds the refined point; *t is advanced together with y
 *  so that the reported time matches the reported state.
 *
 *  Return values:
 *     0    |foldfunc| <= ztol was reached
 *     1    not converged within maxiters steps, or the secant update became
 *          undefined because two successive fold values were equal
 *    -1    evaluating the vector field or taking a step failed (an error
 *          message is printed to stderr); per the GSL documentation a failed
 *          step leaves y at its pre-step value
 */
static int cpldosc_refine_fold(gsl_odeiv_step *step, const gsl_odeiv_system *sys,
                               double *t, double y[2], double p[3],
                               double h, double prevg, double g1,
                               double ztol, int maxiters)
    {
    double dydt_in[2];
    double dydt_out[2];
    double y_err[2];

    if (GSL_ODEIV_FN_EVAL(sys, *t, y, dydt_in) != GSL_SUCCESS)
        {
        fprintf(stderr,"cpldosc_slowvf returned an error while refining the event location.\n");
        return -1;
        }

    while (fabs(g1) > ztol)
        {
        if (maxiters-- <= 0)
            {
            return 1;
            }
        double dg = prevg - g1;
        if (dg == 0.0)
            {
            /* Flat secant: the next step would be infinite. */
            return 1;
            }
        /* Zero of the line through (t-h, prevg) and (t, g1), relative to t. */
        h *= g1 / dg;
        if (gsl_odeiv_step_apply(step, *t, h, y, y_err, dydt_in, dydt_out, sys) != GSL_SUCCESS)
            {
            fprintf(stderr,"gsl_odeiv_step_apply returned an error while refining the event location.\n");
            return -1;
            }
        *t += h;
        prevg = g1;
        g1 = cpldosc_foldfunc(y,p);
        /* The derivative at the new point is the starting derivative of the next step. */
        dydt_in[0] = dydt_out[0];
        dydt_in[1] = dydt_out[1];
        }
    return 0;
    }


/*
 *  cpldosc_slowvf_integrate
 *
 *  Main integration loop of cpldosc_slowvf_solve, using already-allocated GSL
 *  solver objects.  Same inputs, outputs and return values as
 *  cpldosc_slowvf_solve.
 */
static int cpldosc_slowvf_integrate(gsl_odeiv_step *step, gsl_odeiv_control *control,
                                    gsl_odeiv_evolve *evolve,
                                    double y[2], double p[3], double tfinal)
    {
    /*
     *  ODE solver and event-location control parameters.  These should
     *  probably be inputs to the function, rather than being set to specific
     *  values here.
     */
    const double hmax = 0.1;
    const double hmin = 1e-6;
    /* Tolerance on |foldfunc| used to decide that a fold has been reached. */
    const double ztol = 1e-12;
    /* Maximum number of secant iterations when refining a fold. */
    const int maxiters = 10;
    /*
     *  dir determines whether events are detected only when the fold function
     *  is increasing (1) or decreasing (-1).  We set dir=0 because we want to
     *  stop in either case.
     */
    const int dir = 0;

    gsl_odeiv_system sys = {cpldosc_slowvf, cpldosc_slowvf_jac, 2, p};
    double t = 0.0;
    double h = hmax;
    /* Fold function at the start of the current step. */
    double g0 = cpldosc_foldfunc(y,p);

    while (t < tfinal)
        {
        cpldosc_print_point(t, y, p);

        /* Limit the step so the solver does not step carelessly past an event. */
        double hevt = evt_stepsize(&sys,y,hmin,hmax,cpldosc_fold_event,cpldosc_fold_event_grad,NULL);
        if (hevt < 0.0)
            {
            fprintf(stderr,"evt_stepsize returned %f\n",hevt);
            return -1;
            }
        if (h > hevt)
            {
            h = hevt;
            }

        /* Advance the solution by one adaptive step. */
        double tprev = t;
        int rc = gsl_odeiv_evolve_apply(evolve, control, step, &sys, &t, tfinal, &h, y);
        if (rc != GSL_SUCCESS)
            {
            fprintf(stderr,"status=%d\n",rc);
            return -1;
            }

        /*
         *  Check for a change of sign of the event function.  The signs are
         *  compared directly because testing g0*g1 <= 0 can underflow to zero
         *  for tiny values of the same sign.
         */
        double g1 = cpldosc_foldfunc(y,p);
        int crossed = (g0 >= 0.0 && g1 <= 0.0) || (g0 <= 0.0 && g1 >= 0.0);
        if (crossed)
            {
            if (g1 == 0.0)
                {
                if (dir == 0 || (g0 > 0.0 && dir == -1) || (g0 < 0.0 && dir == 1))
                    {
                    /* Remarkably, the solution landed exactly on the event boundary! */
                    cpldosc_print_point(t, y, p);
                    return 0;
                    }
                }
            else if (dir == 0 || (g1 < 0.0 && dir == -1) || (g1 > 0.0 && dir == 1))
                {
                /*
                 *  Refine the event location.  The step actually taken may
                 *  differ from h, since gsl_odeiv_evolve_apply can adjust it.
                 */
                int refined = cpldosc_refine_fold(step, &sys, &t, y, p, t - tprev,
                                                  g0, g1, ztol, maxiters);
                if (refined == 1)
                    {
                    fprintf(stderr,"regula falsi did not converge; after maxiters steps, we have:\n");
                    }
                cpldosc_print_point(t, y, p);
                return (refined < 0) ? -1 : 0;
                }
            /* Otherwise the crossing is in the wrong direction; don't stop here. */
            }
        g0 = g1;
        } /* end while -- end of the main loop */

    /* Reached t=tfinal before hitting a fold. */
    cpldosc_print_point(t, y, p);
    return 1;
    }


/*
 *  cpldosc_slowvf_solve
 *
 *  This function uses a GSL ODE solver (rk8pd with absolute error control)
 *  to solve the slow differential equations of the coupled oscillator
 *  system, starting at t=0.  It stops at a fold (a zero of cpldosc_foldfunc)
 *  or at t=tfinal, whichever comes first.
 *
 *  Input:
 *     y[2] = (v1,v2)                Initial condition.  When the function returns,
 *                                   this will hold the final point.
 *     p[3] = (sigma1,sigma2,omega)  Parameters
 *     tfinal                        Stop when t=tfinal.  We expect to stop at a
 *                                   fold, so typically the function will stop *before*
 *                                   tfinal is reached.
 *
 *  Points in the solution are printed to stdout, in the format
 *     t v1 v2 q1 q2 foldfunc
 *  The final point is printed as well.  (This should probably be changed to
 *  write to a file, or to save the points in an array.)
 *
 *  Return values:
 *     0    Solution stopped at a fold.  If the fold location could not be
 *          refined to within the tolerance, a warning is printed to stderr
 *          and y holds the best point found.
 *     1    Solution stopped at t=tfinal
 *    -1    An error occurred: the GSL solver could not be allocated, or
 *          evt_stepsize, the ODE step or the vector field reported an error.
 *          A message is printed to stderr.
 */
int cpldosc_slowvf_solve(double y[2], double p[3], double tfinal)
    {
    const double abserr = 1e-12;
    const double relerr = 0.0;
    int status = -1;

    gsl_odeiv_step * step       = gsl_odeiv_step_alloc(gsl_odeiv_step_rk8pd, 2);
    gsl_odeiv_control * control = gsl_odeiv_control_y_new(abserr, relerr);
    gsl_odeiv_evolve * evolve   = gsl_odeiv_evolve_alloc(2);

    if (step != NULL && control != NULL && evolve != NULL)
        {
        status = cpldosc_slowvf_integrate(step, control, evolve, y, p, tfinal);
        }
    else
        {
        fprintf(stderr,"cpldosc_slowvf_solve: failed to allocate the GSL ODE solver.\n");
        }

    if (evolve != NULL)
        {
        gsl_odeiv_evolve_free(evolve);
        }
    if (control != NULL)
        {
        gsl_odeiv_control_free(control);
        }
    if (step != NULL)
        {
        gsl_odeiv_step_free(step);
        }
    return status;
    }

