/* spline regression */
/*
 * Least-squares cubic spline regression of Y on X.
 *
 * Arguments are classified by splinefit_parse_args: array arguments
 * fill Y (or an XY series supplying both X and Y), then X, then KNOTS;
 * a scalar sets KNOTS; a string selects the method ("natural",
 * "restricted", "truncated" -> restricted fit; "unrestricted", "cubic"
 * -> unrestricted fit). The default method is unrestricted. Given Y
 * plus exactly one more array and no knots, that array is taken as
 * KNOTS. The default KNOTS is 10.
 *
 * KNOTS may be a series of knot locations or a scalar count N:
 *   N > 0  - N knots at quantile levels linspace(0, 1, N) of X
 *   N <= 0 - abs(N) + 2 evenly spaced knots from min(X) to max(X)
 * For the unrestricted fit, the first and last of these generated
 * knots are dropped.
 *
 * Restricted (natural) fit: design columns are 1, X and one restricted
 * cubic per knot except the last two. The combination used makes the
 * fit linear beyond the last knot. Knots outside the X range are
 * allowed. At least 2 knots are required, and the last two must differ.
 *
 * Unrestricted fit: design columns are 1, X, X^2, X^3 and one truncated
 * cubic (X - knot)^3 for X > knot per knot. Knots must lie within the
 * X range.
 *
 * Returns, depending on the number of outputs requested:
 *   fit   - fitted values (an XY series over X when the input was XY)
 *   c     - linfit coefficients; for restricted fits, followed by the
 *           implied coefficients of the truncated cubics at the
 *           next-to-last and last knot
 *   knots - knots actually used
 *   sse   - sum of squared residuals
 */
splinefit(argv)
{
    local y, x, knots, restricted, isxyin, k;
    local A, kn, kj, k0, k1, km1, c, ck, ck1, fit, sse;

    /* parse args */
    (y, x, knots, restricted, isxyin) = splinefit_parse_args(argv);

    if (isscalar(knots))
    {
        /* a scalar KNOTS is a knot count, not a knot location */
        if (knots > 0)
        {
            /* KNOTS points at evenly spaced quantile levels 0..1 of X */
            /* NOTE(review): confirm percentile() takes levels in [0, 1] rather than percent */
            knots = yvals(percentile(x, linspace(0, 1, knots)));
        }
        else
        {
            /* linear spacing: abs(KNOTS) + 2 points from min(X) to max(X) */
            knots = linspace(min(x), max(x), abs(knots) + 2);
        }
        if (not(restricted))
        {
            /* unrestricted fit uses interior knots only: remove end points */
            knots[{1, end}] = {};
        }
    }

    /* knot columns; stays empty when there are no knot terms */
    kn = {};

    if (restricted)
    {
        k = length(knots);
        if (k < 2)
        {
            error("splinefit - At least 2 knots required for restricted spline fit");
        }
        /* linear design matrix, outlier knots supported */
        A = ravel(ones(length(x), 1), x);
        /* last knot, next-to-last knot and their spacing */
        k0 = knots[k];
        k1 = knots[k - 1];
        km1 = k0 - k1;
        if (km1 == 0)
        {
            /* the restricted basis divides by this spacing */
            error("splinefit - Last two knots must differ for restricted spline fit");
        }
        if (k > 2)
        {
            /* one column per leading knot kj: the truncated cubic at kj minus
               the combination of the last two truncated cubics that cancels
               the cubic and quadratic terms beyond the last knot - vectorized */
            kj = transpose(extract(knots, 1, k - 2));
            kn = (x > kj) * (x - kj)^3 -
                 (x > k1) * (x - k1)^3 * (k0 - kj) / km1 +
                 (x > k0) * (x - k0)^3 * (k1 - kj) / km1;
        }
    }
    else
    {
        /* cubic design matrix */
        A = ravel(ones(length(x), 1), x, x^2, x^3);
        if (length(knots) > 0)
        {
            /* check knots */
            if (min(knots) < min(x) || max(knots) > max(x))
            {
                error("splinefit - Knots must be located within X range for Unrestricted fit");
            }
            /* one truncated cubic column per knot - vectorized */
            kj = transpose(knots);
            kn = (x > kj) * (x - kj)^3;
        }
    }

    /* add knot columns, if any */
    if (not(isempty(kn)))
    {
        A = ravel(A, kn);
    }

    /* least-squares coefficients and fitted values */
    c = linfit(A, y);
    fit = A *^ c;

    /* carry the Y series properties over to the fit */
    fit.deltax = y.deltax;
    fit.vunits = y.vunits;
    fit.hunits = y.hunits;

    if (outargc > 3)
    {
        /* sum of squared residuals, taken on the plain fitted values
           before any XY conversion */
        sse = norm(y - fit)^2;
    }

    if (isxyin)
    {
        fit = xy(x, fit);
    }

    /* return values */
    if (outargc <= 1)
    {
        return(fit);
    }

    if (restricted)
    {
        /* implied coefficients of the truncated cubics at the next-to-last
           knot (ck) and the last knot (ck1); none exist with only 2 knots */
        if (k > 2)
        {
            ck = sum(((knots[1..(k-2)] - k0) * c[3..k]) / km1);
            ck1 = sum(((knots[1..(k-2)] - k1) * c[3..k]) / (-km1));
        }
        else
        {
            ck = ck1 = 0;
        }
        c = {c, ck, ck1};
    }

    if (outargc == 2)
    {
        return(fit, c);
    }
    else if (outargc == 3)
    {
        return(fit, c, knots);
    }
    else
    {
        return(fit, c, knots, sse);
    }
}
/* parse input args */
/*
 * Classify the variable arguments of splinefit.
 *
 * Array arguments fill, in order: Y (an XY series supplies both X and
 * Y), then X, then KNOTS. A scalar argument sets KNOTS, and a string
 * argument selects the method. When no knots were given and Y is not
 * an XY series, a second array is reinterpreted as KNOTS and X falls
 * back to xvals(Y). The default KNOTS is 10, and the default method is
 * unrestricted.
 *
 * Returns (y, x, knots, restricted, xy), where xy is 1 when Y came from
 * an XY series. In that case X and Y are reordered by ascending X.
 */
splinefit_parse_args(argv)
{
    local x, y, knots, restricted = 0, i, a, method = "", isxyin = 0, order;

    x = y = knots = {};

    /* sort each argument by its type */
    loop (i = 1..argc)
    {
        a = getargv(i);
        if (isarray(a))
        {
            if (not(isempty(y)))
            {
                /* Y already taken: the next free array slot is X, then KNOTS */
                if (isempty(x))
                {
                    x = a;
                }
                else if (isempty(knots))
                {
                    knots = a;
                }
            }
            else if (isxy(a))
            {
                /* an XY series provides both coordinates */
                y = yvals(a);
                x = xvals(a);
                isxyin = 1;
            }
            else
            {
                y = a;
            }
        }
        else if (isscalar(a))
        {
            knots = a;
        }
        else if (isstring(a))
        {
            method = a;
        }
    }

    /* a non-empty method string selects restricted or unrestricted */
    if (strlen(method) > 0)
    {
        switch (tolower(method))
        {
            case "natural":
            case "restricted":
            case "truncated":
                restricted = 1;
                break;
            case "unrestricted":
            case "cubic":
                restricted = 0;
                break;
            default:
                error(sprintf("splinefit - Unknown Method '%s'", method));
                break;
        }
    }

    /* knot defaults */
    if (isempty(knots))
    {
        if (not(isempty(y)) && not(isxyin))
        {
            /* with no explicit knots, the array taken as X is really the
               knot list; X then falls back to xvals(Y) */
            if (not(isempty(x)))
            {
                knots = x;
            }
            x = xvals(y);
        }
        if (isempty(knots))
        {
            knots = 10;
        }
    }

    /* X still missing: use the sample positions of Y */
    if (isempty(x))
    {
        x = xvals(y);
        setdeltax(x, 1.0);
    }

    if (isxyin)
    {
        /* reorder samples so X is ascending, keeping Y paired with X */
        order = grade(x, 1);
        y = y[order];
        x = x[order];
    }

    return(y, x, knots, restricted, isxyin);
}