How can I add custom gradients in the Haskell autodifferentiation library "ad"? - haskell

If I want to give a custom or known gradient for a function, how can I do that in the ad library? (I don't want to autodifferentiate through this function.) I am using the grad function in this library.
If the library doesn't provide this feature, is there some way I can easily implement this functionality myself, perhaps by changing the definitions of leaf nodes in ad or by editing the dual numbers that presumably carry the numerical gradients?
Here's a concrete example of what I mean:
Say I have some function I want to take the gradient of, say f(x, y) = x^2 + 3 * g(x, y)^2. Then say that g(x, y) is a function whose definition is complicated, but whose gradient I've already calculated analytically and is quite simple. Thus, when I take grad f and evaluate it at a point (x, y), I'd like to just plug in my custom gradient for g, instead of autodiffing through it: something like my_nice_grad_of_g (x, y).
I see that other autodiff libraries do provide this feature; for example, Stan and TensorFlow both allow users to define the gradient of a function.
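ad is a Haskell library, and its internal API isn't reproduced here. The following is only a language-neutral illustration of the dual-number idea raised above: a minimal Python sketch, assuming forward mode with one pass per input. A function with a known gradient becomes a primitive. Its body runs on plain floats and is never differentiated, and its tangent comes only from the user-supplied gradient via the chain rule. The names Dual, with_custom_gradient, g_plain and my_nice_grad_of_g are hypothetical, and sin(x*y) merely stands in for the complicated g.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """Forward-mode dual number: a value and its tangent (directional derivative)."""
    val: float
    tan: float

    def __add__(self, other):
        other = lift(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other)
        # product rule
        return Dual(self.val * other.val, self.tan * other.val + self.val * other.tan)

    __rmul__ = __mul__


def lift(c):
    """Treat plain numbers as constants (zero tangent)."""
    return c if isinstance(c, Dual) else Dual(float(c), 0.0)


def with_custom_gradient(fn, grad_fn):
    """Turn a plain two-argument float function into a primitive on Dual numbers.

    fn is evaluated on plain floats only; the tangent comes entirely from grad_fn."""
    def wrapped(x, y):
        dgdx, dgdy = grad_fn(x.val, y.val)
        return Dual(fn(x.val, y.val), dgdx * x.tan + dgdy * y.tan)
    return wrapped


def g_plain(x, y):
    # Hypothetical "complicated" g; math.sin cannot see Dual numbers at all
    return math.sin(x * y)


def my_nice_grad_of_g(x, y):
    c = math.cos(x * y)
    return (y * c, x * c)


g = with_custom_gradient(g_plain, my_nice_grad_of_g)


def f(x, y):
    gv = g(x, y)
    return x * x + 3 * gv * gv


def grad(fun, x, y):
    """Gradient via one forward pass per input direction."""
    dfdx = fun(Dual(x, 1.0), Dual(y, 0.0)).tan
    dfdy = fun(Dual(x, 0.0), Dual(y, 1.0)).tan
    return (dfdx, dfdy)


print(grad(f, 1.0, 2.0))
```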

Related

Insulate a segment of code from jax tracing

Apologies in advance for how vague this question is (unfortunately I don't know enough about how jax tracing works to phrase it more precisely), but: Is there a way to completely insulate a function or code block from jax tracing?
For context, I have a function of the form:
def f(x, y):
    z = h(y)
    return g(x, z)
Essentially, I want to call g(x, z), and treat z as a constant when doing any jax transformations. However, setting up the argument z is very awkward, so the helper function h is used to transform an easier-to-specify input y into the format required by g. What I'd like is for jax to treat h as a non-traceable black box, so that doing jit(lambda x: f(x, y0)) for a particular y0 is the same as first computing z0 = h(y0) with numpy, then doing jit(lambda x: g(x, z0)) (and similar with grad or whatever other function transformations).
In my code, I've already written h to only use standard numpy (which I thought might lead to black-box behaviour), but the compile time of jit(lambda x: f(x, y0)) is noticeably longer than the compile time of jit(lambda x: g(x, z0)) for z0 = h(y0). I have a feeling the compile time may have something to do with jax tracing the many loops in h, though I'm not sure.
Some additional notes:
Writing h in a jax-friendly way would be awkward (input formatting is ragged, tons of looping/conditionals, output shape dependent on input value, etc) and ultimately more trouble than it's worth as the function is extremely cheap to execute, and I don't ever need to differentiate it (the input data is integer-based).
Thoughts?
Edit addition for clarity: I know there may be ways around this if, e.g., f is a top-level function. In this case it isn't such a big deal to get the user to call h first to "pre-compile" the jax-friendly inputs to g, then freely perform whatever jax transformations they want on lambda x: g(x, z0). However, I'm imagining cases in which we have many functions with the same structure as f that we want to chain together, where there are some jax-unfriendly inputs/computations, but these inputs will always be treated as constant by the jax part of the computation. In principle one could always pull these pre-computations out to set up the jax stuff, but this seems difficult if we have a non-trivial collection of functions of this type that call each other.
Is there some way to control how f gets traced, so that while tracing it knows to just evaluate z=h(y) (instead of tracing h) then continue with tracing g(x, z)?
The static_argnums parameter could probably help:
f_jitted = jax.jit(f, static_argnums=1)
From https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html:
You can use transformation parameters such as static_argnums for jit to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.
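Here is a minimal sketch of that suggestion. It assumes the jax-unfriendly input y can be passed as a hashable value (here a tuple of tuples, since lists and NumPy arrays are not hashable). Because y is static, h(y) runs as ordinary Python/NumPy during tracing, and its result is embedded in the compiled function as a constant. The trade-off mentioned in the quote is that every new y triggers a new trace and compile. The bodies of h and g are hypothetical stand-ins, and the h_calls counter exists only to show when h actually runs.

```python
import numpy as np
import jax
import jax.numpy as jnp

h_calls = 0


def h(y):
    """Hypothetical jax-unfriendly helper: Python loops over ragged integer data."""
    global h_calls
    h_calls += 1
    return np.array([sum(row) for row in y], dtype=np.float32)


def g(x, z):
    return jnp.sum(jnp.sin(x) * z)


def f(x, y):
    z = h(y)  # y is static, so this runs as ordinary NumPy while jit traces f
    return g(x, z)


# Argument 1 (y) is static: it must be hashable, and each new value of y
# triggers a fresh trace and compile.
f_jitted = jax.jit(f, static_argnums=1)

y0 = ((1, 2), (3,), (4, 5, 6))  # ragged data as a tuple of tuples (hashable)
x0 = jnp.zeros(3)
print(f_jitted(x0, y0))
print(jax.grad(lambda x: f_jitted(x, y0))(x0))
```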

Is there a way to supply a numerical function to JiTCODE’s function argument instead of symbolic one?

I am getting a function (a learned dynamical system) through a neural network and want to pass it to JiTCODE to calculate trajectories, Lyapunov exponents, etc. As per the JiTCODE documentation, the function f has to be a symbolic function. Is there any way to change this since ultimately JiTCODE is going to lambdify the symbolic function?
Basically, this is what I'm doing right now:
# learns derivatives from the neural-net model
# returns an array of numbers [\dot{x},\dot{y}] for input [x,y]
learned_fn = lambda t, y0: NN_model(t, y0)
ODE = jitcode_lyap(learned_fn, n_lyap=2)
ODE.set_integrator("vode")
First, beware that JiTCODE does not take regular functions like your learned_fn as input. It takes either iterables of symbolic expressions or generator functions returning symbolic expressions. This is why your example code will likely produce an error.
What you are asking for
You can “inject” any derivative with the right signature into JiTCODE by changing the f property and telling it that it failed compiling the actual derivative. Here is a minimal example doing this:
from jitcode import jitcode, y
ODE = jitcode([0])
ODE.f = lambda t,y: y[0]
ODE.compile_attempt = False
ODE.set_integrator("dopri5")
ODE.set_initial_value([1],0.0)
for time in range(30):
    print(time, *ODE.integrate(time))
Why you probably do not want to do this
Ignoring Lyapunov exponents for a second, the entire point of JiTCODE is to hard-code your derivative for you and pass it to SciPy’s ode or solve_ivp, which perform the actual integration. Thus the above example code is just an overly complicated way of passing a function to one of SciPy’s standard integrators (here ode), with no advantage. If your NN_model is already efficiently implemented, you may not even gain a speed boost from JiTCODE’s auto-compilation.
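For comparison, here is a minimal sketch of passing an ordinary Python function straight to SciPy, using the same toy derivative as the JiTCODE example. Here learned_fn is a hypothetical stand-in for NN_model, and solve_ivp plays the role of ode. The method DOP853 (a higher-order Dormand-Prince method) is chosen only for accuracy, while dopri5 would correspond to RK45.

```python
import numpy as np
from scipy.integrate import solve_ivp


def learned_fn(t, y):
    # Hypothetical stand-in for NN_model: the toy ODE dy/dt = y[0]
    return [y[0]]


times = np.arange(30)
sol = solve_ivp(learned_fn, (0.0, 29.0), [1.0], method="DOP853",
                t_eval=times, rtol=1e-10, atol=1e-12)
for time, value in zip(sol.t, sol.y[0]):
    print(time, value)
```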
The main reason to use JiTCODE’s Lyapunov-exponent capabilities is that it automatically obtains the Jacobian and the ODE for the tangent-vector evolution (needed for the Benettin method) from the symbolic representation of the derivative. Without a symbolic input, it cannot possibly do this. You could theoretically inject a tangent-vector ODE as well, but then you would leave little for JiTCODE to do, and you would probably be better off using SciPy’s ode or solve_ivp directly.
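To make the Benettin step concrete, here is a rough sketch of what you would have to supply by hand without a symbolic input:

- the Jacobian,
- the tangent-vector ODE dv/dt = J(y) v,
- periodic QR re-orthonormalization, whose logarithmic stretching factors average to the Lyapunov exponents.

The sketch uses a hypothetical linear test system whose exponents are exactly 1 and -2, so the result can be checked. It is not JiTCODE's implementation.

```python
import numpy as np
from scipy.integrate import solve_ivp


def rhs(t, y):
    # Hypothetical test system with known Lyapunov exponents 1 and -2
    return np.array([y[0], -2.0 * y[1]])


def jacobian(t, y):
    # Jacobian of rhs; JiTCODE derives this symbolically, here it is written by hand
    return np.array([[1.0, 0.0], [0.0, -2.0]])


def augmented(t, state, n):
    y = state[:n]
    V = state[n:].reshape(n, n)   # columns are tangent vectors
    dV = jacobian(t, y) @ V       # tangent-vector ODE: dv/dt = J(y) v
    return np.concatenate([rhs(t, y), dV.ravel()])


def lyapunov_spectrum(y0, dt=0.1, steps=200):
    """Benettin-style estimate: evolve state and tangent vectors, re-orthonormalize
    with QR after every step, and average the logarithmic stretching factors."""
    n = len(y0)
    y = np.asarray(y0, dtype=float)
    V = np.eye(n)
    log_sums = np.zeros(n)
    for _ in range(steps):
        sol = solve_ivp(augmented, (0.0, dt), np.concatenate([y, V.ravel()]),
                        args=(n,), rtol=1e-10, atol=1e-12)
        state = sol.y[:, -1]
        y = state[:n]
        Q, R = np.linalg.qr(state[n:].reshape(n, n))
        log_sums += np.log(np.abs(np.diag(R)))
        V = Q
    return log_sums / (steps * dt)


print(lyapunov_spectrum([0.5, 0.5]))
```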
What you probably need
If you want to use JiTCODE, you need to write a small piece of code that translates the output of your neural-network training into the symbolic representation of your ODE that JiTCODE needs. This is probably much less scary than it sounds: you just need to obtain the trained coefficients and insert them into the equations of the general form of the neural network.
If you are lucky and your NN_model fully supports duck typing (and ), you may do something like this:
from jitcode import t,y
n = 10 # dimension of your ODE
NN_input = [y(i) for i in range(n)]
learned_fn = NN_model(t,NN_input)[1]
The idea is that you feed NN_model once with abstract symbolic input (t and NN_input). NN_model then acts once on this abstract input, giving you an abstract result (this is where you need the duck-typing support). If I interpreted the output of your NN_model correctly, the second component of this result should be the abstract derivative that JiTCODE requires as input.
Note that your NN_model appears to expect dimensions to be indices, but JiTCODE’s y expects dimensions to be function arguments. Thus you cannot just choose NN_input = y, but you have to transform it as above.
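Here is a minimal sketch of what "supports duck typing" means in this context. It assumes a hypothetical one-layer network with made-up weights W and b. The model only uses +, * and a tanh that is passed in, so the same code runs on floats and on symbols. To stay self-contained, it uses plain SymPy symbols. With JiTCODE you would pass [y(i) for i in range(n)] instead, possibly together with a tanh from JiTCODE's symbolic backend (SymEngine, as far as I know) rather than sympy.tanh. Unlike the snippet above, this hypothetical model returns the derivative vector directly, so there is no [1].

```python
import math
import sympy

# Hypothetical trained parameters of a tiny one-layer network
W = [[0.5, -1.0], [2.0, 0.3]]
b = [0.1, -0.2]


def NN_model(t, state, tanh=math.tanh):
    """Return the derivative vector. Only +, * and an injectable tanh are used,
    so the same code runs on floats and on SymPy symbols."""
    n = len(state)
    return [tanh(sum(W[i][j] * state[j] for j in range(n)) + b[i]) for i in range(n)]


t = sympy.Symbol("t")
state_symbols = sympy.symbols("y0 y1")  # with JiTCODE: [y(i) for i in range(n)]
symbolic_rhs = NN_model(t, state_symbols, tanh=sympy.tanh)
print(symbolic_rhs)
```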
To quote directly from the linked documentation
JiTCODE takes an iterable (or generator function or dictionary) of symbolic expressions, which it translates to C code, compiles on the fly,
so there is no lambdification going on; the expressions are parsed, not just evaluated.
But in general that should be no problem: you just use the JiTCODE-provided symbolic vector y and symbol t instead of the function arguments t,y of the right-hand side of the ODE.

fitting for offset in a patsy model

Using patsy, I understand how to turn intercepts on or off. But I haven't managed to get horizontal offsets. For instance, I would like to be able to fit, in essence
y = alpha + beta * abs(x_opt - x_obs)
with x_opt free in the fit. I tried to write this like so:
y ~ 1 + np.abs(y - x)
using a constant column for y. But within the np.abs() parentheses, patsy "turns off," and y - x is just interpreted as a number. If I shift y to 1 or 20, I get different answers.
A similar question applies to, e.g., np.pow(1-x, 2) or a sine wave. Being able to fit for the x offset would be extremely helpful. Is this possible? Or is this precisely what is meant by saying that patsy doesn't do non-linear models?
patsy and most of statsmodels only handle models that are linear in parameters. Or more precisely, models where the design matrix and estimated parameters are combined in a linear way, x * beta.
Polynomials and splines are nonlinear in the underlying variables, but they have a linear representation in terms of basis functions and are therefore linear in parameters.
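To illustrate "linear in parameters" with the question's own model, here is a sketch using plain NumPy and synthetic, noise-free data. Once x_opt is fixed, |x_opt - x| is just another column of the design matrix, and ordinary least squares recovers alpha and beta. The same holds for polynomial basis columns. Making x_opt itself a free parameter is what breaks this structure.

```python
import numpy as np

# Synthetic, noise-free data from y = 1 + 2 * |4 - x| (illustration only)
x = np.linspace(0.0, 10.0, 50)
y = 1.0 + 2.0 * np.abs(4.0 - x)

# With x_opt fixed, |x_opt - x| is one more design-matrix column,
# so the model is linear in (alpha, beta).
x_opt = 4.0
X = np.column_stack([np.ones_like(x), np.abs(x_opt - x)])
coef, *_ = np.linalg.lstsq(X, y, rcond=None)
print(coef)  # close to [1. 2.]

# A quadratic is nonlinear in x but linear in its coefficients:
# the columns 1, x, x**2 are fixed basis functions.
X_poly = np.column_stack([np.ones_like(x), x, x**2])
poly_coef, *_ = np.linalg.lstsq(X_poly, 3.0 - x + 0.5 * x**2, rcond=None)
print(poly_coef)  # close to [3. -1. 0.5]
```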
The only non-linearities in the models that are currently implemented in statsmodels are predefined nonlinearities like link functions in GLM or discrete models, shape parameters in models like NegativeBinomial, or covariances in mixed models and GEE.
The best Python package for nonlinear least squares is currently lmfit https://pypi.python.org/pypi/lmfit/
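lmfit builds on SciPy's optimizers. As a dependency-light alternative, here is a sketch of the same idea using scipy.optimize.curve_fit instead, fitting x_opt as a free parameter. The data are synthetic with a little seeded noise. The starting guess p0 matters for a kinked model like this one.

```python
import numpy as np
from scipy.optimize import curve_fit


def model(x, alpha, beta, x_opt):
    return alpha + beta * np.abs(x_opt - x)


# Synthetic data (illustration only): true parameters 1, 2, 4 plus small noise
rng = np.random.default_rng(0)
x = np.linspace(0.0, 10.0, 50)
y = model(x, 1.0, 2.0, 4.0) + rng.normal(0.0, 0.01, size=x.size)

params, cov = curve_fit(model, x, y, p0=(0.5, 1.5, 4.5))
print(params)  # alpha, beta, x_opt
```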

How to set up sympy to perform standard differential geometry tasks?

I'm an engineering student. Pretty much all math I have to do is something in R2 or R3 and concerns differential geometry. Naturally I really like sympy because it makes my calculations reusable and presentable.
What I found:
The thing in sympy that comes closest to what I know as a function (a named mapping from scalar or vector values to scalar or vector values, connected to an expression) seems to be something of the form
functionname=sympy.Lambda(Variables in tuple, Expression)
or as an example
f=sympy.Lambda((x),x+1)
I also found that sympy has the diffgeom module that defines Manifolds, Patches and can then perform some operations on functions without expressions or points. Like translating a point in a coordinate system to the same point in a different, linked coordinate system.
I haven't found a way to perform those operations and transformations on functions like those above. Or to define something in the diffgeom context that performs like the Lambda function.
Examples of what I'd like to do:
scalarfield f (x,y,z) = expression
grad (f) = ( d expression / dx , d expression / dy , d expression / dz)^T
vectorfield v (x,y,z) = ( expression 1 , expression 2 , expression 3 )^T
I'd then like to be able to integrate the vectorfield over bodies or curves.
Do these things exist and I haven't found them?
Are they doable with diffgeom and I didn't understand it?
Would I have to write this myself with the backbones that sympy already provides?
There is a differential geometry module within sympy:
http://docs.sympy.org/latest/modules/diffgeom.html
For more examples you can see http://blog.krastanov.org/pages/diff-geometry-in-python.html
To do what you suggest with the diffgeom module, just define your expression using the base coordinates of your manifold:
from sympy.diffgeom.rn import R2
scalar = (R2.r**2 - R2.x**2 - R2.y**2) # you can mix coordinate systems
gradient = (R2.e_x + R2.e_y).rcall(scalar) # derivative along e_x + e_y, i.e. the sum of the gradient's components
There are various functions for change of coordinates, etc. Probably many things are missing, but it would take usage and bug reports (and help) for all this to get implemented.
You can see some other examples in the test files:
tested examples from a textbook https://github.com/sympy/sympy/blob/master/sympy/diffgeom/tests/test_function_diffgeom_book.py
more tests https://github.com/sympy/sympy/blob/master/sympy/diffgeom/tests/test_diffgeom.py
However, for doing what is suggested in your question, going through differential geometry (while possible) would be overkill. You can just use the matrices module:
from sympy import Matrix
def gradient(expr, vars):
    return Matrix([expr.diff(v) for v in vars])
Fancier things, such as Jacobians (Matrix.jacobian), are implemented as well.
A final remark: using expressions instead of functions and lambdas will probably result in more readable and idiomatic sympy code (it is often more natural to use subs to substitute a symbol than some kind of closure, lambda, function call, etc).
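Here is a sketch of the question's examples (scalar field, gradient, vector field, integration along a curve) using plain expressions, Matrix and subs, in the spirit of the remarks above. The fields and the curve are arbitrary examples. The line integral uses the standard parametrization formula, the integral of w(r(t)) · r'(t) dt.

```python
from sympy import Matrix, cos, integrate, pi, simplify, sin, symbols

x, y, z, t = symbols("x y z t")


def gradient(expr, variables):
    return Matrix([expr.diff(v) for v in variables])


f = x**2 * y + sin(z)            # scalar field as a plain expression
grad_f = gradient(f, (x, y, z))  # Matrix([2*x*y, x**2, cos(z)])
print(grad_f.subs({x: 1, y: 2, z: 0}))

v = Matrix([x * y, y * z, z * x])  # vector field
divergence = v.jacobian([x, y, z]).trace()

# Line integral of w = (-y, x, 0) around the unit circle r(t) = (cos t, sin t, 0)
w = Matrix([-y, x, 0])
r = Matrix([cos(t), sin(t), 0])
integrand = simplify(w.subs({x: r[0], y: r[1], z: r[2]}).dot(r.diff(t)))
line_integral = integrate(integrand, (t, 0, 2 * pi))
print(line_integral)
```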

Kernel SVM primal with Stochastic Gradient Descent

In short: I am currently reading Online Learning with Kernels (http://books.nips.cc/papers/files/nips14/AA33.pdf) for fun and I can't figure out how he got to equation 8 from equations 6 and 7.
The idea is: We want to minimize a risk function
$R_{\mathrm{stoch}}[f,t] := c(x_t, y_t, f(x_t)) + \lambda\,\Omega[f]$
If we want to apply the representer theorem to f, writing it as
$f(x) = \sum_i \alpha_i\, k(x, x_i)$
how can we get to the STOCHASTIC gradient descent update?
A set of k(xi, x) seems to form a basis of H, and since f is in H, then f can be written as a linear combination of "kernel functions".
So, assuming the set of k(xi, x) forms a basis of H, it's obvious that if some linear combination on the left-hand side equals another on the right-hand side, their basis-vector coefficients must be equal too (it's a well-known fact from linear algebra that two vectors are equal exactly when their coefficients in the same basis are equal).
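Filling in the remaining step requires a few assumptions: the regularizer is $\Omega[f] = \tfrac{1}{2}\|f\|_H^2$ (as I believe the paper uses), η is a step size, k is symmetric, and c' denotes the derivative of c with respect to its third argument. The reproducing property $f(x_t) = \langle f, k(x_t,\cdot)\rangle_H$ gives $\partial_f f(x_t) = k(x_t,\cdot)$. The functional gradient is therefore itself a kernel expansion, and the coefficients can be matched term by term as described above. Check the result against equation 8 in the paper.

```latex
\begin{aligned}
f_t(x) &= \sum_{i<t} \alpha_i\, k(x, x_i) \\
\partial_f R_{\mathrm{stoch}}[f,t](x) &= c'(x_t, y_t, f(x_t))\, k(x, x_t) + \lambda f(x) \\
f_{t+1}(x) &= f_t(x) - \eta\, \partial_f R_{\mathrm{stoch}}[f,t](x)\big|_{f=f_t} \\
  &= \sum_{i<t} (1 - \eta\lambda)\,\alpha_i\, k(x, x_i) - \eta\, c'(x_t, y_t, f_t(x_t))\, k(x, x_t) \\
\Rightarrow\quad \alpha_i &\leftarrow (1 - \eta\lambda)\,\alpha_i \quad (i < t), \qquad
\alpha_t = -\eta\, c'(x_t, y_t, f_t(x_t))
\end{aligned}
```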
