File:  [CENS] / python / pyGiNaC / wrappers3 / function_py.cpp
Revision 1.5: download - view: text, annotated - select for diffs - revision graph
Wed May 16 05:49:32 2001 UTC (16 years, 6 months ago) by pearu
Branches: MAIN
CVS tags: HEAD
Added some get_*() methods. Introduced function().

/*
# This file is part of the PyGiNaC package.
# http://cens.ioc.ee/projects/pyginac/
#
# $Revision: 1.5 $
# $Id: function_py.cpp,v 1.5 2001-05-16 05:49:32 pearu Exp $
#
# Copyright 2001 Pearu Peterson all rights reserved,
# Pearu Peterson <pearu@cens.ioc.ee>
# Permission to use, modify, and distribute this software is given under the
# terms of the LGPL.  See http://www.fsf.org
#
# NO WARRANTY IS EXPRESSED OR IMPLIED.  USE AT YOUR OWN RISK.
#
*/
/*DT
  Function
  ------

  To create a GiNaC function, use build_function:
  >>> from ginac import build_function

  To make a function with zero parameters:
  >>> foo = build_function('foo')
  >>> foo()
  foo()
  >>> print foo()
  foo()

  To make a function with one parameter:
  >>> foo = build_function('foo',nofargs=1)
  >>> foo(3)
  foo(numeric('3'))
  >>> print foo(3)
  foo(3)

  etc. Maximum number of parameters is 12:
  >>> bar = build_function('bar',nofargs=12)
  >>> print bar(1,2,3,4,5,6,7,8,9,0,1,2)
  bar(1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2)
  >>> bar2 = build_function('bar2',nofargs=13)
  Traceback (most recent call last):
  ...
  NotImplementedError: _build_function() first argument must be less than 12

  An exception is raised, when function is used with the wrong
  number of parameters:
  >>> foo(3,4)
  Traceback (most recent call last):
  ...
  TypeError: function foo() takes exactly 1 argument (2 given)

  For automatic evaluation use user-defined evaluation function:
  >>> from ginac import numeric
  >>> def gun_eval(x):
  ...     if x.is_negative(): return 0
  ...     return gun(x,hold=1)
  >>> gun = build_function('gun',eval_f=gun_eval)
  >>> print gun(-1), gun(3)
  0 gun(3)

  Important:
  Note the use of the `hold' keyword, which stops further automatic
  evaluation of the function. You **MUST** use it when no more
  evaluation is desired. Otherwise the evaluation function is called
  recursively until something nasty happens (the recursion depth limit
  is reached, an out of memory error occurs, ...).

  Similarly, you can supply functions for numerical evaluation (evalf_f,
  must take nofargs arguments), for the derivative (derivative_f, must take
  nofargs+1 arguments), and for series expansion (series_f, must take
  nofargs+3 arguments). For example,
  >>> def sun_derivative(x,i):
  ...     # return derivative in respect to the i-th parameter.
  ...     assert int(i)==0
  ...     if x.is_negative(): return 0
  ...     return x*sun(x)
  >>> sun = build_function('sun', derivative_f=sun_derivative)
  >>> print sun(3)
  sun(3)
  >>> from ginac import symbol
  >>> z = symbol('z')
  >>> print sun(z).diff(z)
  z * sun(z)
  >>> print sun(z).diff(z,2)
  z ** 2 * sun(z) + sun(z)

  Note that build_function is able to determine the number of parameters
  automatically from the supplied functions eval_f, evalf_f, derivative_f, series_f.
  In some cases, however, this is impossible. Therefore, giving nofargs explicitly
  is always recommended.

  Auxiliary functions:
  >>> from ginac import is_function, sin
  >>> is_function(sin(2))
  1
  >>> sin(2).get_name()
  'sin'
*/

/*F_DT build_function(name,texname=None,eval_f=None,evalf_f=None,derivative_f=None,series_f=None,nofargs=None)
  Construct symbolic function with a name `name'.

  Optional arguments:
  texname      - LaTeX string;
  nofargs      - the number of arguments to a function (if not specified,
                 build_function tries to establish `nofargs' from functions `eval',
		 `evalf', `derivative', `series'. If that fails, 12 is set.);
  eval_f       - eval_f(*args) is used for automatic non-interruptive symbolic
                 evaluation;
  evalf_f      - evalf_f(*args) is used for numeric evaluation;
  derivative_f - derivative_f(*args,i) is used for computing
                 the 1st derivative with respect to `i'-th argument (i>=0);
  series_f     - series_f(*args,rel,order,opts = 0) is used for computing
                 the truncated series expansion around a point `rhs(rel)'
                 with respect to variable `lhs(rel)'. `order' is the truncation
                 order.
 */

// Expands to the signature of a callback wrapper: a function returning
// GiNaC::ex whose name is `<name>_func<num>_w`.  ARGS is pasted directly after
// the name, so it has to supply the complete parameter list, parentheses
// included.
#define PYGINAC_CB_FUNC_PROTO(name,num,ARGS) GiNaC::ex name##_func##num##_w ARGS

#ifdef PYGINAC_DEFS
// Registration section (compiled only when PYGINAC_DEFS is defined): exposes
// the functions of this file on `this_module` and on `ex_class`, both of which
// are declared outside this file.
this_module.def(&function_g, "_function");
this_module.def(&pyfunc_1, "_pyfunc");
// "is_function" is registered twice on the module: once with ex::is_function
// and once with return_false (defined elsewhere).
// NOTE(review): presumably the second acts as a fallback overload for
// arguments that are not ex instances -- confirm against return_false.
this_module.def(&ex::is_function, "is_function");
this_module.def(&return_false, "is_function");
ex_class.def(&ex::is_function, "is_function");
ex_class.def(&ex::get_serial, "get_serial");
// def_raw: build_function receives the raw (args, kws) objects and parses
// them itself.
this_module.def_raw(&build_function, "_build_function");
#else
#ifdef PYGINAC_EX_PROTOS
#define PYGINAC_PROTOS
#endif
#ifdef PYGINAC_PROTOS
#ifdef PYGINAC_EX_PROTOS
// Member declarations for the `ex` wrapper (this section is selected by
// PYGINAC_EX_PROTOS); the definitions are ex::is_function and ex::get_serial
// further down in this file.
bool is_function(void) const;
unsigned get_serial(void) const;
#else // PYGINAC_EX_PROTOS
#include "pyfunc.hpp"
// Helper subclass of GiNaC::function that offers a public static accessor to
// the table of registered function options.
// NOTE(review): presumably needed because the base-class accessor is not
// publicly reachable -- confirm against the GiNaC headers.
class wrap_function: public GiNaC::function {
public:
  // Forwards to GiNaC::function::registered_functions() and hands back the
  // very same table by reference.
  static std::vector<GiNaC::function_options>& registered_functions () {
    std::vector<GiNaC::function_options>& table = GiNaC::function::registered_functions();
    return table;
  }
};
// Per-callback caches filled by build_function(): each maps the serial number
// of a function built there to the Python callable supplied for the
// corresponding hook (eval, evalf, derivative, series).  The callback wrappers
// generated by PYGINAC_CB_FUNC look the callable up by serial number.
std::map<unsigned, py::ref> function_w_eval_cache;
std::map<unsigned, py::ref> function_w_evalf_cache;
std::map<unsigned, py::ref> function_w_derivative_cache;
std::map<unsigned, py::ref> function_w_series_cache;
// Serial numbers of all functions created through build_function() (stored as
// ser -> ser); function_g() only tests for membership.
std::map<unsigned, unsigned> function_w_ser_cache;
// Raw-argument entry point registered as "_build_function"; returns the serial
// number of the newly registered function.
unsigned build_function(PyObject * args, PyObject * kws);
#include "function_py_subs.cpp"
// Earlier two-argument signature of function_g, kept for reference:
//GiNaC::ex function_g(unsigned ser, const GiNaC::exvector & v);
// Builds the function with serial `ser` applied to the arguments `v`; when
// `hold` is set the result is returned via hold().
GiNaC::ex function_g(unsigned ser, bool hold, const GiNaC::exvector & v);
// Returns a pyfunc object carrying the given serial number, as an ex.
GiNaC::ex pyfunc_1(unsigned);
BOOST_PYTHON_BEGIN_CONVERSION_NAMESPACE
// Python -> GiNaC::function conversion; accepts only ex instances that are
// exactly of type function.
GiNaC::function from_python (PyObject* o, py::type<const GiNaC::function &>);
BOOST_PYTHON_END_CONVERSION_NAMESPACE
#endif // !PYGINAC_EX_PROTOS
#else  // PYGINAC_PROTOS
BOOST_PYTHON_BEGIN_CONVERSION_NAMESPACE
// Convert a Python object into a GiNaC::function.
// The object has to pass ExInstance_Check and the ex it converts to has to be
// exactly of type function; every other input ends up in the
// PYGINAC_FROMPYTHON_TYPEERROR macro (defined elsewhere).
GiNaC::function from_python (PyObject* o, py::type<const GiNaC::function &>) {
  if (ExInstance_Check(o)) {
    // Go through the ex converter first, then narrow to function.
    GiNaC::ex wrapped = from_python(o, py::type<const GiNaC::ex &>());
    if (is_ex_exactly_of_type(wrapped, function)) {
      return ex_to_function(wrapped);
    }
  }
  PYGINAC_FROMPYTHON_TYPEERROR(function,ex(function));
}
BOOST_PYTHON_END_CONVERSION_NAMESPACE
/*F_DT is_function(obj)
  Check if `obj' is function.
*/
/*M_DT is_function(self)
  Check if object is function.
*/
// True when this ex holds an object that is exactly of type function.
bool ex::is_function(void) const {
  const bool exactly_function = is_ex_exactly_of_type(*this, function);
  return exactly_function;
}
/*F_DT _build_function(nofargs,name,texname=None,eval=None,evalf=None,derivative=None,series=None)
  Build GiNaC function with Python eval,evalf,derivative,series
  callback functions.	       
  Return the serial number of constructed function.
  Used internally. See build_function().
*/
// Build and register a GiNaC function whose eval/evalf/derivative/series
// hooks are Python callables.
//
// args, kws - raw positional and keyword arguments of the Python call; the
//             accepted names are listed in kwlist below (`nofargs' and `name'
//             are required, the rest are optional and default to None).
// Returns the serial number under which the function was registered.
// Every failure is reported by setting a Python exception and throwing
// py::error_already_set.
unsigned build_function(PyObject * args, PyObject * kws) {
  unsigned ser;
  // Serial number the new function is expected to receive: the current size
  // of the registry.  It is used as the cache key before registration and is
  // checked against the value returned by register_new() afterwards.
  unsigned ser1 = wrap_function::registered_functions().size();
  PyObject * name_py = Py_None;
  PyObject * texname_py = Py_None;
  PyObject * nofargs_py = Py_None;
  PyObject * eval_py = Py_None;
  PyObject * evalf_py = Py_None;
  PyObject * derivative_py = Py_None;
  PyObject * series_py = Py_None;
  static char *kwlist[] = {"nofargs","name","texname",
                           "eval","evalf","derivative","series",NULL};
  if (!PyArg_ParseTupleAndKeywords(args,kws,"OO|OOOOO:_ginac._build_function",kwlist,
                                   &nofargs_py, &name_py, &texname_py, &eval_py,
                                   &evalf_py, &derivative_py, &series_py))
    throw py::error_already_set();
  if (!PyInt_Check(nofargs_py)) {
    PyErr_SetString(PyExc_TypeError,"_build_function() first argument must be an int");
    throw py::error_already_set();
  }
  if (!PyString_Check(name_py)) {
    PyErr_SetString(PyExc_TypeError,"_build_function() second argument must be a string");
    throw py::error_already_set();
  }
  GiNaC::function_options opt;
  opt = opt.overloaded(-1);
  if (texname_py == Py_None)
    opt = opt.set_name(PyString_AsString(name_py));
  else {
    if (!PyString_Check(texname_py)) {
      PyErr_SetString(PyExc_TypeError,"_build_function() `texname' keyword argument must be a string");
      throw py::error_already_set();
    }
    opt = opt.set_name(PyString_AsString(name_py),PyString_AsString(texname_py));
  }
  // Range-check the untruncated long: narrowing to int first could wrap an
  // out-of-range value into 0..12 on platforms where long is wider than int.
  // Accepted values are 0..12 inclusive; the message text is kept verbatim
  // because the doctests at the top of this file match on it.
  const long nofargs_l = PyInt_AsLong(nofargs_py);
  if (nofargs_l>12 || nofargs_l<0) {
    PyErr_SetString(PyExc_NotImplementedError,"_build_function() first argument must be less than 12");
    throw py::error_already_set();
  }
  const int nofargs = static_cast<int>(nofargs_l);
  // One parameter more than the user-visible count: function_g() requires the
  // last argument of a function built here to be an ex(pyfunc).
  opt.test_and_set_nparams(nofargs+1);
  // PYGINAC_BUILDFUNC_SUB(hook): if the `hook' callable was supplied, verify
  // that it is callable, install the C++ wrapper matching `nofargs' into the
  // options, and remember the callable in function_w_<hook>_cache under ser1.
  // The Py_INCREF precedes handing the pointer to py::ref.
  // NOTE(review): this relies on py::ref(PyObject*) adopting a reference
  // rather than adding one -- confirm against the Boost.Python version in use.
#define PYGINAC_BUILDFUNC_SUB(eval) \
  if (!(eval##_py == Py_None)) { \
    if (!PyCallable_Check(eval##_py)) { \
      PyErr_SetString(PyExc_TypeError,"_build_function() '"#eval"' keyword argument must be callable");\
      throw py::error_already_set(); \
    } \
    switch (nofargs) { \
    case 0: opt = opt.eval##_func(eval##_func0_w); break; \
    case 1: opt = opt.eval##_func(eval##_func1_w); break; \
    case 2: opt = opt.eval##_func(eval##_func2_w); break; \
    case 3: opt = opt.eval##_func(eval##_func3_w); break; \
    case 4: opt = opt.eval##_func(eval##_func4_w); break; \
    case 5: opt = opt.eval##_func(eval##_func5_w); break; \
    case 6: opt = opt.eval##_func(eval##_func6_w); break; \
    case 7: opt = opt.eval##_func(eval##_func7_w); break; \
    case 8: opt = opt.eval##_func(eval##_func8_w); break; \
    case 9: opt = opt.eval##_func(eval##_func9_w); break; \
    case 10: opt = opt.eval##_func(eval##_func10_w); break; \
    case 11: opt = opt.eval##_func(eval##_func11_w); break; \
    case 12: opt = opt.eval##_func(eval##_func12_w); break; \
    default: \
      PyErr_SetString(PyExc_RuntimeError,"_build_function() failure ("#eval" nofargs bug)"); \
      throw py::error_already_set(); \
    } \
    Py_INCREF(eval##_py); \
    function_w_##eval##_cache[ser1] = py::ref(eval##_py); \
  }

  PYGINAC_BUILDFUNC_SUB(eval);
  PYGINAC_BUILDFUNC_SUB(evalf);
  PYGINAC_BUILDFUNC_SUB(derivative);
  PYGINAC_BUILDFUNC_SUB(series);

  ser = wrap_function::register_new(opt);
  // The caches above were keyed on the predicted serial; a different actual
  // serial would leave them pointing at the wrong function.
  if (ser != ser1) {
    PyErr_SetString(PyExc_RuntimeError, "_build_function() mismatch of function serial numbers");
    throw py::error_already_set();
  }
  // Mark the serial as belonging to a function built here (see function_g()).
  function_w_ser_cache[ser]=ser;
  return ser;
}

/* PYGINAC_CB_FUNC(name,num,ARGS,num2,setitem)
   Defines the callback wrapper `<name>_func<num>_w ARGS' that forwards to the
   Python callable stored in function_w_<name>_cache.

   name    - hook kind; selects the cache and appears in error messages;
   num     - suffix of the generated function name;
   ARGS    - full parameter list; the body refers to a parameter called `ser',
             which must be an ex(pyfunc) carrying the function's serial number;
   num2    - size of the argument tuple passed to the Python callable;
   setitem - statement(s) expanded just before the call.
             NOTE(review): presumably these fill `args' -- confirm in
             function_py_subs.cpp.

   A failed Python call (NULL result) and both error cases below are reported
   by throwing py::error_already_set with a Python exception set.

   The "not in cache" message is terminated with std::ends: os.str() hands out
   the raw character buffer, which PyErr_SetString reads as a C string.
   str() also freezes that buffer, so it is unfrozen again once the text has
   been passed to PyErr_SetString, letting the stream release it.
   NOTE(review): assumes PyErr_SetString takes its own copy of the text.
*/
#define PYGINAC_CB_FUNC(name,num,ARGS,num2,setitem) \
PYGINAC_CB_FUNC_PROTO(name,num,ARGS) { \
  py::tuple args(num2); \
  if (!is_ex_exactly_of_type(ser,pyfunc)) { \
    PyErr_SetString(PyExc_RuntimeError,#name "_func"#num "() last argument must be ex(pyfunc)"); \
    throw py::error_already_set(); \
  } \
  unsigned ser_ = GiNaC::ex_to_pyfunc(ser).get_serial(); \
  const std::map<unsigned, py::ref>::const_iterator viter = function_w_##name##_cache.find(ser_); \
  if (viter==function_w_##name##_cache.end()) { \
    std::ostrstream os; \
    os << #name "_func"#num "() failed to find "#name " function in cache "; \
    os << "(ser_="<<ser_<<")" << std::ends; \
    PyErr_SetString(PyExc_RuntimeError,os.str()); \
    os.freeze(false); \
    throw py::error_already_set(); \
  } \
  setitem \
  PyObject * ret = NULL; \
  ret = PyObject_CallObject(viter->second.get(), args.get()); \
  if (ret==NULL) \
    throw py::error_already_set(); \
  return ex_from_ref(ret); \
}

#include "function_py_subs.cpp"

/*F_DT _function(ser,hold,seq)
  Return image of a function with serial number `ser' and
  arguments `seq'. Last item in `seq' must be ex(pyfunc) object.
  Used internally.
*/
// Apply the registered function with serial number `ser' to the arguments `v'.
//
// ser  - index into the table of registered functions;
// hold - when true the result is returned through hold();
// v    - the arguments; their count must equal the function's registered
//        parameter count, and for functions created by build_function() the
//        last one must be an ex(pyfunc).
// Raises a Python TypeError (via py::error_already_set) when `ser' is out of
// bounds, the argument count is wrong, or the trailing ex(pyfunc) is missing.
GiNaC::ex function_g(unsigned ser, bool hold, const GiNaC::exvector & v) {
  unsigned l = v.size();
  if (wrap_function::registered_functions().size() <= ser) {
    PyErr_SetString(PyExc_TypeError,"_function() invalid serial argument (out of bounds)");
    throw py::error_already_set();
  }
  unsigned el = wrap_function::registered_functions()[ser].get_nparams();
  if (el != l) {
    std::ostrstream os;
    os << "_function(ser="<<ser<<") invalid number of parameters (expected "<<el<<", got "<<l<<")" << std::ends;
    PyErr_SetString(PyExc_TypeError,os.str());
    // str() froze the buffer; unfreeze so the stream frees it on destruction.
    // NOTE(review): assumes PyErr_SetString has taken its own copy of the text.
    os.freeze(false);
    throw py::error_already_set();
  }
  const std::map<unsigned, unsigned>::const_iterator viter = function_w_ser_cache.find(ser);
  // Only functions created by build_function() are listed in the serial cache.
  // The l == 0 test keeps v[l-1] from indexing an empty vector.
  if (viter!=function_w_ser_cache.end() && (l == 0 || !is_ex_exactly_of_type(v[l-1], pyfunc))) {
    std::ostrstream os;
    os << "_function(ser="<<ser<<") last argument must be ex(pyfunc) of user-defined fun"<<std::ends;
    PyErr_SetString(PyExc_TypeError,os.str());
    os.freeze(false);
    throw py::error_already_set();
  }
  if (hold)
    return GiNaC::function(ser, v, true).hold();
  return GiNaC::function(ser, v, true);
}

/*F_DT _pyfunc(ser)
  Return ex(pyfunc) holding serial number `ser'.
  Used internally.
*/
// Wrap the serial number `ser' in a GiNaC::pyfunc and return it as an ex.
GiNaC::ex pyfunc_1(unsigned ser) {
  GiNaC::ex wrapped = GiNaC::pyfunc(ser);
  return wrapped;
}

/*M_DT get_serial(self)
  Return serial number of a function.
  `self' must be ex(function).
 */
// Serial number of the function held by this ex.
// Raises a Python NotImplementedError (via py::error_already_set) when the
// object is not a function.
unsigned ex::get_serial(void) const {
  if (!is_function()) {
    PyErr_SetString(PyExc_NotImplementedError, "ex.get_serial() can be used only for ex(function)");
    throw py::error_already_set();
  }
  return ex_to_function(*this).get_serial();
}

#endif // !PYGINAC_PROTOS
#endif // !PYGINAC_DEFS




FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>