File:  [CENS] / python / pyGiNaC / wrappers / function.py
Revision 1.4: download - view: text, annotated - select for diffs - revision graph
Tue Apr 17 22:39:25 2001 UTC (16 years, 7 months ago) by pearu
Branches: MAIN
CVS tags: HEAD
Fixed bugs. Exposed/checked functions. Started testing framework.

# This file is part of the PyGiNaC package.
# http://cens.ioc.ee/projects/pyginac/
#
# $Revision: 1.4 $
# $Id: function.py,v 1.4 2001-04-17 22:39:25 pearu Exp $
#
# Copyright 2001 Pearu Peterson all rights reserved,
# Pearu Peterson <pearu@cens.ioc.ee>
# Permission to use, modify, and distribute this software is given under the
# terms of the LGPL.  See http://www.fsf.org
#
# NO WARRANTY IS EXPRESSED OR IMPLIED.  USE AT YOUR OWN RISK.
#


depends = ['basic']

uses = ['exvector','ex','numeric','relational']

#STARTPROTO

class function(exprseq):
    """Function images."""
    def __init__(self,*args):
        """func(*args) - construct image of a function func.
    func - builtin GiNaC function or defined by build_function().
    """
    def __coerce__(self,other):
        pass


#ENDPROTO

#STARTDOCTEST
def test_04_function():
    """
To create a GiNaC function, use build_function:
>>> from ginac import build_function

To make a function with zero parameters:
>>> foo = build_function('foo')
>>> foo()
foo()
>>> print foo()
foo()

To make a function with one parameter:
>>> foo = build_function('foo',nofargs=1)
>>> foo(3)
foo(numeric('3'))
>>> print foo(3)
foo(3)

etc. Maximum number of parameters is 9:
>>> bar = build_function('bar',nofargs=9)
>>> print bar(1,2,3,4,5,6,7,8,9)
bar(1, 2, 3, 4, 5, 6, 7, 8, 9)

An exception is raised, when function is used with the wrong
number of parameters:
>>> foo(3,4)
Traceback (most recent call last):
...
TypeError: foo() requires 1 arguments (2 given)

For automatic evaluation use predefined evaluation function:
>>> from ginac import unex, numeric
>>> def gun_eval(x):
...     x = unex(x)
...     if isinstance(x, numeric) and x.is_negative():
...          return 0
...     return gun(x).hold()
>>> gun = build_function('gun',eval_f=gun_eval)
>>> print gun(-1), gun(3)
0 gun(3)

Note the use of hold() method which stops futher automatic
evaluation of the function. You must use it, otherwise the evaluation
function is called recursively until something nasty happens
(recursion depth is reached, out of memory error,...).

Similary, you can submit functions for numerical evaluation (evalf_f,
must take nofargs arguments), for derivative (derivative_f, must take
nofargs+1 arguments), and for series expansion (series_f, must take
nofargs+3 arguments). For example,
>>> def sun_derivative(x,i):
...     # return derivative in respect to the i-th parameter.
...     assert int(unex(i))==0
...     x = unex(x)
...     if isinstance(x, numeric) and x.is_negative():
...          return 0
...     return x*sun(x)
>>> sun = build_function('sun', derivative_f=sun_derivative)
>>> print sun(3)
sun(3)
>>> from ginac import symbol
>>> z = symbol('z')
>>> print sun(z).diff(z)
sun(z) * z
>>> print sun(z).diff(z,2)
sun(z) + sun(z) * z ** 2

Note that build_function is able to determine the number of parameters
automatically from the given functions eval_f, evalf_f, derivative_f, series_f.
But in some occasions it is impossible. Therefore, giving the nofargs explicitly
is always recommended.


"""
#ENDDOCTEST

wrapperclass = '''
std::map<unsigned, python::ref> function_w_eval_cache;
std::map<unsigned, python::ref> function_w_evalf_cache;
std::map<unsigned, python::ref> function_w_derivative_cache;
std::map<unsigned, python::ref> function_w_series_cache;

class function_w : public GiNaC::function {
  PyObject * self;
public:
  function_w(python::ref other) {
    DEBUG_C("function_w::function_w(raw:ref)");
    PyErr_SetString(PyExc_NotImplementedError, "function(raw:ref)");
    throw python::error_already_set();    
  }
  function_w(PyObject * self_): GiNaC::function(), self(self_) {
    DEBUG_C("function_w::function_w()");
  }
  function_w(PyObject * self_, const GiNaC::function & other)
  : GiNaC::function(other), self(self_) {
    DEBUG_C("function_w::function_w(function)");
  }
  function_w(PyObject * self_, const GiNaC::ex & other)
  : GiNaC::function(ex_to_function_w(other)), self(self_) {
    DEBUG_C("function_w::function_w(ex)");
  }
'''
for i in range(nof_function_params+1):
    a = ', '.join(['python::ref a%s'%j for j in range(i)]+['unsigned ser'])
    b = ', '.join(['ex_w(a%s)'%j for j in range(i)] + ['GiNaC::pyfunc(ser)'])
    wrapperclass += '''
  function_w(PyObject * self_, %s)
  : GiNaC::function(ser, %s), self(self_) {
    DEBUG_C("function_w::function_w("<<ser<<",%s*ref)");
    if (registered_functions().size() <= ser) {
      PyErr_SetString(PyExc_TypeError,"function() invalid serial argument (out of bounds)");
      throw python::error_already_set();
    }
    const std::map<unsigned, unsigned>::const_iterator viter = function_w_ser_cache.find(ser);
    if (viter==function_w_ser_cache.end()) {
      PyErr_SetString(PyExc_TypeError,"function() invalid serial argument (builtin function)");
      throw python::error_already_set();
    }    
    if (registered_functions()[ser].get_nparams() != %s) {
      cout << "nparams=" << registered_functions()[ser].get_nparams() << std::endl;
      PyErr_SetString(PyExc_TypeError,"function() invalid number of parameters (expected %s)");
      throw python::error_already_set();
    }
  }
  '''%(a,b,i,i+1,i+1)

wrapperclass += '''
  ~function_w() {
    DEBUG_C("function_w::~function_w()");
  }

  static unsigned build_function(PyObject * args, PyObject * kws);

  GiNaC::ex hold_w(void) {
    DEBUG_M("function.hold()");
    this->hold();
    return *this;
  }
  unsigned nops_w (void) const {
    unsigned n = this->nops();
    if (n==0) return n;
    if (is_ex_exactly_of_type(this->op(n-1), pyfunc))
      return n-1;
    return n;
  }
  GiNaC::ex op_w(int i) const {
    if ((i>=(int)this->nops_w()) || (i<0)) {
      PyErr_SetString(PyExc_IndexError, "function.op index out of range");
      throw python::error_already_set();
    }
    return this->op(i);
  }
private:
'''

for k in ['eval','evalf','derivative','series']:
    for i in range(nof_function_params+1):
        a = ', '.join(['const GiNaC::ex & a%s'%j for j in range(i)]+['const GiNaC::ex & ser'])
        b = '\n'.join(['    args.set_item(%s,UNEX(a%s));'%(j,j) for j in range(i)])
        num = str(i)
        num2 = num
        if k=='derivative':
            a += ', unsigned n'
            b += '\n   args.set_item(%s, n);'%(i)
            num2 = str(i+1)
        if k=='series':
            a += ', const GiNaC::relational & rel, int order, unsigned opt'
            b += '\n   args.set_item(%s, UNEX(GiNaC::ex(rel)));'%(i)
            b += '\n   args.set_item(%s, order);'%(i+1)
            b += '\n   args.set_item(%s, opt);'%(i+2)
            num2 = str(i+3)
        wrapperclass += '''
  static GiNaC::ex #name#_func#num#_w(#args#) {
    DEBUG_M("#name#_func#num#_w()");
    python::tuple args(#num2#);
    if (!is_ex_exactly_of_type(ser,pyfunc)) {
      PyErr_SetString(PyExc_RuntimeError,"#name#_func#num#() last argument must be pyfunc object");
      throw python::error_already_set();
    }
    unsigned ser_ = GiNaC::ex_to_pyfunc(ser).get_serial();
    DEBUG_M("ser_="<<ser_);
    const std::map<unsigned, python::ref>::const_iterator viter = function_w_#name#_cache.find(ser_);
    if (viter==function_w_#name#_cache.end()) {
      PyErr_SetString(PyExc_RuntimeError,"#name#_func#num#() failed to find #name# function in cache");
      throw python::error_already_set();
    }
    #setitem#
    PyObject * ret = NULL;
    ret = PyObject_CallObject(viter->second.get(), args.get());
    if (ret==NULL)
      throw python::error_already_set();
    return ex_w(python::ref(ret));
  }
'''.replace('#name#',k).replace('#num#',num).replace('#num2#',num2).replace('#args#',a).replace('#setitem#',b)

wrapperclass += '''
};
'''

protos = '''
'''

builder = '''
python::class_builder<function_w> function_w_class(this_module, "_function_w");
python::class_builder<GiNaC::function, function_w> function_class(this_module, "_function");
function_class.declare_base(function_w_class);
function_class.declare_base(basic_class);
function_py_class = python::as_object(function_class.get_extension_class());
'''

constructors = '''
function_class.def(python::constructor<>());
function_class.def(python::constructor<const GiNaC::function &>());
function_class.def(python::constructor<const GiNaC::ex &>());
'''
for i in range(nof_function_params+1):
    a = ', '.join(i*['python::ref']+['unsigned'])
    constructors += '''
function_class.def(python::constructor<%s>());'''%(a)

defs = '''
function_class.def(&basic_w::python_str, "__str__");
function_class.def(&basic_w::python_repr, "__repr__");

function_class.def(&function_w::hold_w, "hold");
function_class.def(&function_w::nops_w, "nops");
function_class.def(&function_w::op_w, "op");

this_module.def_raw(&function_w::build_function, "_build_function");


function_class.def(&basic_w::coerce, "__coerce__");
BASIC_OPS(function)

'''

implementation = '''
EX_TO_BASIC(function)

/*
static int get_nof_args(PyObject* fun) {
  PyObject * tmp = NULL;
  int res;
  if (PyCallable_Check(fun))
    if (PyObject_HasAttrString(fun,"func_code"))
      if (PyObject_HasAttrString(tmp = PyObject_GetAttrString(fun,"func_code"),"co_argcount")) {
        res = PyInt_AsLong(PyObject_GetAttrString(tmp,"co_argcount"));
        Py_XDECREF(tmp);
        return res;
      }
  return -1;
}
*/

unsigned function_w::build_function(PyObject * args, PyObject * kws) {
  DEBUG_C("function_w::build_function(string)");
  unsigned ser;
  unsigned ser1 = registered_functions().size();
  DEBUG_C("ser1="<<ser1);
  PyObject * name_py = Py_None;
  PyObject * texname_py = Py_None;
  PyObject * nofargs_py = Py_None;
  PyObject * eval_py = Py_None;
  PyObject * evalf_py = Py_None;
  PyObject * derivative_py = Py_None;
  PyObject * series_py = Py_None;
  static char *kwlist[] = {"nofargs","name","texname","eval","evalf","derivative","series",NULL};
  if (!PyArg_ParseTupleAndKeywords(args,kws,"OO|OOOOO:_ginac._build_function",kwlist,
    &nofargs_py, &name_py, &texname_py, &eval_py, &evalf_py, &derivative_py, &series_py))
    throw python::error_already_set();
  if (!PyInt_Check(nofargs_py)) {
    PyErr_SetString(PyExc_TypeError,"_build_function() first argument must be an int");
    throw python::error_already_set();
  }
  if (!PyString_Check(name_py)) {
    PyErr_SetString(PyExc_TypeError,"_build_function() second argument must be a string");
    throw python::error_already_set();
  }
  GiNaC::function_options opt;
  opt = opt.overloaded(-1);
  if (texname_py == Py_None)
    opt = opt.set_name(PyString_AsString(name_py));
  else {
    if (!PyString_Check(texname_py)) {
      PyErr_SetString(PyExc_TypeError,"_build_function() `texname\' keyword argument must be a string");
      throw python::error_already_set();
    }
    opt = opt.set_name(PyString_AsString(name_py),PyString_AsString(texname_py));
  }
  int nofargs = PyInt_AsLong(nofargs_py);
  if (nofargs>%s || nofargs<0) {
    PyErr_SetString(PyExc_NotImplementedError,"_build_function() first argument must be less than %s");
    throw python::error_already_set();
  }
  opt.test_and_set_nparams(nofargs+1);
'''%(nof_function_params,nof_function_params)

for k in ['eval','evalf']:
    implementation += '''
  if (!(%s_py == Py_None)) {
    if (!PyCallable_Check(%s_py)) {
      PyErr_SetString(PyExc_TypeError,"_build_function() \'%s\' keyword argument must be callable");
      throw python::error_already_set(); 
    }
    switch (nofargs) {\n'''%(k,k,k)
    for i in range(nof_function_params+1):
        implementation += '      case %s: opt = opt.%s_func(%s_func%s_w); break;\n'%(i,k,k,i)
    implementation += '''
    default:
      PyErr_SetString(PyExc_RuntimeError,"_build_function() failure (eval/evalf nofargs bug)");
      throw python::error_already_set(); 
    }
    Py_INCREF(%s_py);
    function_w_%s_cache[ser1] = python::ref(%s_py);
  }
'''%(k,k,k)

if 1:
    implementation += '''
  if (!(derivative_py == Py_None)) {
    if (!PyCallable_Check(derivative_py)) {
      PyErr_SetString(PyExc_TypeError,"_build_function() \'derivative\' keyword argument must be callable");
      throw python::error_already_set(); 
    }
    switch (nofargs) {
'''
    for i in range(nof_function_params+1):
        implementation += '      case %s: opt = opt.derivative_func(derivative_func%s_w); break;\n'%(i,i)
    implementation += '''
    default:
      PyErr_SetString(PyExc_RuntimeError,"_build_function() failure (derivative nofargs bug)");
      throw python::error_already_set(); 
    }
    Py_INCREF(derivative_py);
    function_w_derivative_cache[ser1] = python::ref(derivative_py);
  }
'''

if 1:
    implementation += '''
  if (!(series_py == Py_None)) {
    if (!PyCallable_Check(series_py)) {
      PyErr_SetString(PyExc_TypeError,"_build_function() \'series\' keyword argument must be callable");
      throw python::error_already_set(); 
    }
    switch (nofargs) {\n'''
    for i in range(nof_function_params+1):
        implementation += '      case %s: opt = opt.series_func(series_func%s_w); break;\n'%(i,i)
    implementation += '''
    default:
      PyErr_SetString(PyExc_RuntimeError,"_build_function() failure (series nofargs bug)");
      throw python::error_already_set(); 
    }
    Py_INCREF(series_py);
    function_w_series_cache[ser1] = python::ref(series_py);
  }
'''

implementation += '''  ser = register_new(opt);
  if (ser != ser1) {
    PyErr_SetString(PyExc_RuntimeError, "_build_function() mismatch of function serial numbers");
    throw python::error_already_set();
  }
  function_w_ser_cache[ser]=ser;
  return ser;
}

'''


FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>