/* View Raw SPL */
/*
 * splinefit - cubic spline regression
 *
 * The arguments are forwarded as argv to splinefit_parse_args, which
 * yields the dependent values y, the abscissa x, the knots (either a
 * scalar count or an explicit array), the restricted flag and a flag
 * telling whether the input was an XY series.
 *
 * Scalar knots:
 *   knots > 0   knot locations are taken from percentiles of x
 *   otherwise   abs(knots) + 2 linearly spaced values over [min(x), max(x)]
 *   For an unrestricted fit the first and last generated values are
 *   removed.
 *
 * Returns, depending on the number of requested outputs (outargc):
 *   fit                   fitted values (rebuilt with xy() for XY input)
 *   fit, c                plus the coefficients; for a restricted fit the
 *                         two derived coefficients ck and ck1 are appended
 *   fit, c, knots         plus the knot locations
 *   fit, c, knots, sse    plus the sum of squared residuals
 *
 * Errors:
 *   restricted fit with fewer than 2 knots
 *   unrestricted fit with a knot outside [min(x), max(x)]
 */
splinefit(argv)
{
        /* every working variable is declared so nothing leaks out of the function */
        local A, kn, kj, k0, k1, km1, c, ck, ck1, fit, restricted, isxyin,
              y, x, knots, k, sse;

        /* parse args */
        (y, x, knots, restricted, isxyin) = splinefit_parse_args(argv);

        if (isscalar(knots))
        {
                /* get knot values */
                if (knots > 0)
                {
                        /* use quantiles */
//                        knots = quantile(x, linspace(0, 1, knots));
                        knots = yvals(percentile(x, linspace(0, 1, knots)));
                }
                else
                {
                        /* linear spacing */
                        knots = linspace(min(x), max(x), abs(knots) + 2);
                }

                if (not(restricted))
                {
                        /* remove end points */
                        knots[{1, end}] = {};
                }
        }

        if (restricted)
        {
                k = length(knots);

                if (k < 2)
                {
                        error("splinefit - At least 2 knots required for restricted spline fit");
                }

                /* truncated design matrix, outlier knots supported */
                A = ravel(ones(length(x), 1), x);

                /*
                 * knot terms: k0 is the last knot, k1 the one before it and
                 * kj the remaining leading knots as a row.
                 * NOTE(review): km1 is used as a divisor, so the last two
                 * knots are assumed to be distinct - confirm with callers.
                 */
                k0  = knots[k];
                k1  = knots[k - 1];
                km1 = k0 - k1;
                kj  = transpose(extract(knots, 1, k - 2));

                /* build restricted knots columns - vectorized */
                kn  = (x > kj) * (x - kj)^3 -
                          (x > k1) * (x - k1)^3 * (k0 - kj) / km1 + 
                        (x > k0) * (x - k0)^3 * (k1 - kj) / km1;
        }
        else
        {
                /* check knots */
                if (min(knots) < min(x) || max(knots) > max(x))
                {
                        error("splinefit - Knots must be located within X range for Unrestricted fit");
                }

                /* cubic design matrix */
                A = ravel(ones(length(x), 1), x, x^2, x^3);

                /* unrestricted knot columns - vectorized */
                kj = transpose(knots);
                kn = (x > kj) * (x - kj)^3;
        }

        /* add knots */
        A = ravel(A, kn);

        /* solve */
        c = linfit(A, y);

        /* fit */
        fit = A *^ c;

        /* properties */
        fit.deltax = y.deltax;
        fit.vunits = y.vunits;
        fit.hunits = y.hunits;

        /*
         * sum of residuals squared, only needed for a fourth output.
         * Computed from the plain fitted values, before any XY conversion,
         * so the residual does not depend on how an XY series combines
         * with y.
         */
        if (outargc > 3)
        {
                sse = norm(y - fit)^2;
        }

        if (isxyin)
        {
                /* input was an XY series, so return the fit in the same form */
                fit = xy(x, fit);
        }

        /* return values */
        if (outargc <= 1)
        {
                return(fit);
        }
        else
        {
                if (restricted)
                {
                        /* compute c[k] and c[k+1] for restricted fit */
                        ck  = sum(((knots[1..(k-2)] - k0) * c[3..k]) / km1);
                        ck1 = sum(((knots[1..(k-2)] - k1) * c[3..k]) / (-km1));

                        c  = {c, ck, ck1};
                }

                if (outargc == 2)
                {
                        return(fit, c);
                }
                else if (outargc == 3)
                {
                        return(fit, c, knots);
                }
                else
                {
                        return(fit, c, knots, sse);
                }
        }
}


/*
 * splinefit_parse_args - classify the caller's arguments by type
 *
 * Array arguments fill y, then x, then knots, in that order; an XY
 * series given while y is still empty supplies both y and x.  A scalar
 * argument is taken as the knots value and a string as the method name.
 *
 * With no knots given, a second array (when the first was not XY) is
 * reinterpreted as the knots and x is taken from xvals(y); failing
 * that, knots defaults to 10.
 *
 * Method "natural", "restricted" or "truncated" selects a restricted
 * fit; "unrestricted" or "cubic" selects an unrestricted fit; any other
 * non-empty string raises an error.  Matching ignores case.
 *
 * Returns (y, x, knots, restricted, xy-flag).  For XY input, x and y
 * are reordered so that x is ascending.
 */
splinefit_parse_args(argv)
{
        local x, y, knots, restricted = 0, n, a, mstr = "", key, xyflag = 0, ord;

        y     = {};
        x     = {};
        knots = {};

        /* walk the argument list and slot each value by its type */
        loop (n = 1..argc)
        {
                a = getargv(n);

                if (isarray(a))
                {
                        if (isempty(y) && isxy(a))
                        {
                                /* XY series carries both coordinates */
                                y      = yvals(a);
                                x      = xvals(a);
                                xyflag = 1;
                        }
                        else if (isempty(y))
                        {
                                y = a;
                        }
                        else if (isempty(x))
                        {
                                x = a;
                        }
                        else if (isempty(knots))
                        {
                                knots = a;
                        }
                }
                else if (isscalar(a))
                {
                        knots = a;
                }
                else if (isstring(a))
                {
                        mstr = a;
                }
        }

        /* no explicit knots */
        if (isempty(knots))
        {
                if (not(isempty(y) || xyflag))
                {
                        /* a second plain array is the knots, not x */
                        if (not(isempty(x)))
                        {
                                knots = x;
                        }

                        x = xvals(y);
                }

                if (isempty(knots))
                {
                        /* default knots value */
                        knots = 10;
                }
        }

        /* method string */
        if (strlen(mstr) > 0)
        {
                key = tolower(mstr);

                switch (key)
                {
                        case "natural":
                        case "restricted":
                        case "truncated":
                                restricted = 1;
                                break;

                        case "unrestricted":
                        case "cubic":
                                restricted = 0;
                                break;

                        default:
                                error(sprintf("splinefit - Unknown Method '%s'", mstr));
                                break;
                }
        }

        /* still no abscissa: take it from y */
        if (isempty(x))
        {
                x = xvals(y);

                setdeltax(x, 1.0);
        }

        if (xyflag)
        {
                /* reorder x in ascending order and corresponding y */
                ord = grade(x, 1);
                x   = x[ord];
                y   = y[ord];
        }

        return(y, x, knots, restricted, xyflag);
}