Annotation of python/pyGiNaC/wrappers/function.py, revision 1.4

1.1       pearu       1: # This file is part of the PyGiNaC package.
                      2: # http://cens.ioc.ee/projects/pyginac/
                      3: #
1.4     ! pearu       4: # $Revision: 1.3 $
        !             5: # $Id: function.py,v 1.3 2001/04/11 20:31:30 pearu Exp $
1.1       pearu       6: #
                      7: # Copyright 2001 Pearu Peterson all rights reserved,
                      8: # Pearu Peterson <pearu@cens.ioc.ee>
                      9: # Permission to use, modify, and distribute this software is given under the
                     10: # terms of the LGPL.  See http://www.fsf.org
                     11: #
                     12: # NO WARRANTY IS EXPRESSED OR IMPLIED.  USE AT YOUR OWN RISK.
                     13: #
                     14: 
                     15: 
                     16: depends = ['basic']
                     17: 
                     18: uses = ['exvector','ex','numeric','relational']
                     19: 
1.3       pearu      20: #STARTPROTO
                     21: 
class function(exprseq):
    """Function images, i.e. a GiNaC function applied to arguments."""
    # NOTE(review): this class lies between #STARTPROTO/#ENDPROTO and looks like
    # a signature/documentation prototype only; the runtime behavior comes from
    # the generated C++ wrapper (function_w). Confirm against the generator.
    def __init__(self,*args):
        """func(*args) - construct an image of the function func.
    func - a builtin GiNaC function or one defined by build_function().
    """
    def __coerce__(self,other):
        # Stub; the exported __coerce__ is basic_w::coerce (registered in defs).
        pass
                     30: 
                     31: 
                     32: #ENDPROTO
                     33: 
1.4     ! pearu      34: #STARTDOCTEST
def test_04_function():
    """
To create a GiNaC function, use build_function:
>>> from ginac import build_function

To make a function with zero parameters:
>>> foo = build_function('foo')
>>> foo()
foo()
>>> print foo()
foo()

To make a function with one parameter:
>>> foo = build_function('foo',nofargs=1)
>>> foo(3)
foo(numeric('3'))
>>> print foo(3)
foo(3)

And so on; the maximum number of parameters is 9:
>>> bar = build_function('bar',nofargs=9)
>>> print bar(1,2,3,4,5,6,7,8,9)
bar(1, 2, 3, 4, 5, 6, 7, 8, 9)

An exception is raised when a function is called with the wrong
number of parameters:
>>> foo(3,4)
Traceback (most recent call last):
...
TypeError: foo() requires 1 arguments (2 given)

For automatic evaluation, supply an evaluation function as eval_f:
>>> from ginac import unex, numeric
>>> def gun_eval(x):
...     x = unex(x)
...     if isinstance(x, numeric) and x.is_negative():
...          return 0
...     return gun(x).hold()
>>> gun = build_function('gun',eval_f=gun_eval)
>>> print gun(-1), gun(3)
0 gun(3)

Note the use of the hold() method, which stops further automatic
evaluation of the function. You must use it; otherwise the evaluation
function is called recursively until something nasty happens
(maximum recursion depth reached, out-of-memory error, ...).

Similarly, you can supply functions for numerical evaluation (evalf_f,
must take nofargs arguments), for differentiation (derivative_f, must
take nofargs+1 arguments: the parameters followed by the index of the
parameter to differentiate with respect to), and for series expansion
(series_f, must take nofargs+3 arguments: the parameters followed by the
expansion relation, the order and the options). For example,
>>> def sun_derivative(x,i):
...     # return the derivative with respect to the i-th parameter.
...     assert int(unex(i))==0
...     x = unex(x)
...     if isinstance(x, numeric) and x.is_negative():
...          return 0
...     return x*sun(x)
>>> sun = build_function('sun', derivative_f=sun_derivative)
>>> print sun(3)
sun(3)
>>> from ginac import symbol
>>> z = symbol('z')
>>> print sun(z).diff(z)
sun(z) * z
>>> print sun(z).diff(z,2)
sun(z) + sun(z) * z ** 2

Note that build_function is able to determine the number of parameters
automatically from the given functions eval_f, evalf_f, derivative_f,
series_f. On some occasions, however, this is impossible, so giving
nofargs explicitly is always recommended.


"""
        !           110: #ENDDOCTEST
        !           111: 
# C++ source for the function_w wrapper class, first part: the per-serial
# callback caches and the constructors. The class body is continued and
# closed further down in this module.
# Each cache maps a function serial number (as assigned by build_function())
# to the Python callable supplied for eval/evalf/derivative/series.
wrapperclass = '''
std::map<unsigned, python::ref> function_w_eval_cache;
std::map<unsigned, python::ref> function_w_evalf_cache;
std::map<unsigned, python::ref> function_w_derivative_cache;
std::map<unsigned, python::ref> function_w_series_cache;

class function_w : public GiNaC::function {
  PyObject * self;
public:
  function_w(python::ref other) {
    DEBUG_C("function_w::function_w(raw:ref)");
    PyErr_SetString(PyExc_NotImplementedError, "function(raw:ref)");
    throw python::error_already_set();
  }
  function_w(PyObject * self_): GiNaC::function(), self(self_) {
    DEBUG_C("function_w::function_w()");
  }
  function_w(PyObject * self_, const GiNaC::function & other)
  : GiNaC::function(other), self(self_) {
    DEBUG_C("function_w::function_w(function)");
  }
  function_w(PyObject * self_, const GiNaC::ex & other)
  : GiNaC::function(ex_to_function_w(other)), self(self_) {
    DEBUG_C("function_w::function_w(ex)");
  }
'''
for i in range(nof_function_params+1):
    # Constructor taking i Python arguments plus the serial number of a
    # function created by build_function(). The serial is also appended as a
    # trailing GiNaC::pyfunc operand; the static trampolines read it back to
    # locate the Python callbacks for this function.
    a = ', '.join(['python::ref a%s'%j for j in range(i)]+['unsigned ser'])
    b = ', '.join(['ex_w(a%s)'%j for j in range(i)] + ['GiNaC::pyfunc(ser)'])
    # The registered nparams includes the hidden pyfunc operand, hence i+1 in
    # the check and "- 1" when reporting the user-visible expected arity.
    # Debug output goes through DEBUG_C instead of a stray stdout print.
    wrapperclass += '''
  function_w(PyObject * self_, %(params)s)
  : GiNaC::function(ser, %(baseargs)s), self(self_) {
    DEBUG_C("function_w::function_w("<<ser<<",%(nargs)s*ref)");
    if (registered_functions().size() <= ser) {
      PyErr_SetString(PyExc_TypeError,"function() invalid serial argument (out of bounds)");
      throw python::error_already_set();
    }
    const std::map<unsigned, unsigned>::const_iterator viter = function_w_ser_cache.find(ser);
    if (viter==function_w_ser_cache.end()) {
      PyErr_SetString(PyExc_TypeError,"function() invalid serial argument (builtin function)");
      throw python::error_already_set();
    }
    if (registered_functions()[ser].get_nparams() != %(nparams)s) {
      DEBUG_C("nparams=" << registered_functions()[ser].get_nparams());
      PyErr_Format(PyExc_TypeError,
                   "function() invalid number of parameters (expected %%d, got %(nargs)s)",
                   (int)registered_functions()[ser].get_nparams() - 1);
      throw python::error_already_set();
    }
  }
  ''' % {'params': a, 'baseargs': b, 'nargs': i, 'nparams': i+1}
1.1       pearu     161: 
# Remainder of the public part of function_w: destructor, the static
# build_function() entry point (its body is emitted in `implementation`), and
# the helpers exported to Python as hold(), nops() and op().
# nops_w()/op_w() hide a trailing pyfunc operand, which is where functions
# created by build_function() carry their serial number.
wrapperclass += '''
  ~function_w() {
    DEBUG_C("function_w::~function_w()");
  }

  static unsigned build_function(PyObject * args, PyObject * kws);

  GiNaC::ex hold_w(void) {
    DEBUG_M("function.hold()");
    this->hold();
    return *this;
  }
  unsigned nops_w (void) const {
    unsigned n = this->nops();
    if (n==0) return n;
    if (is_ex_exactly_of_type(this->op(n-1), pyfunc))
      return n-1;
    return n;
  }
  GiNaC::ex op_w(int i) const {
    if ((i>=(int)this->nops_w()) || (i<0)) {
      PyErr_SetString(PyExc_IndexError, "function.op index out of range");
      throw python::error_already_set();
    }
    return this->op(i);
  }
private:
'''

# Emit one static trampoline per callback kind and arity
# (0..nof_function_params). Each receives the function's arguments plus the
# trailing pyfunc operand `ser`, looks up the Python callable in
# function_w_<kind>_cache by that serial, packs the arguments into a tuple
# (via UNEX(), presumably ex -> Python object; confirm) and calls it.
#   derivative: additionally passes the `unsigned n` argument.
#   series:     additionally passes rel, order and opt (3 extra items).
for k in ['eval','evalf','derivative','series']:
    for i in range(nof_function_params+1):
        # a: C++ parameter list; b: statements that fill the argument tuple.
        a = ', '.join(['const GiNaC::ex & a%s'%j for j in range(i)]+['const GiNaC::ex & ser'])
        b = '\n'.join(['    args.set_item(%s,UNEX(a%s));'%(j,j) for j in range(i)])
        num = str(i)
        # num2: size of the Python argument tuple.
        num2 = num
        if k=='derivative':
            a += ', unsigned n'
            b += '\n   args.set_item(%s, n);'%(i)
            num2 = str(i+1)
        if k=='series':
            a += ', const GiNaC::relational & rel, int order, unsigned opt'
            b += '\n   args.set_item(%s, UNEX(GiNaC::ex(rel)));'%(i)
            b += '\n   args.set_item(%s, order);'%(i+1)
            b += '\n   args.set_item(%s, opt);'%(i+2)
            num2 = str(i+3)
        # Placeholders are substituted with chained str.replace; none of the
        # substituted texts contains a placeholder, so the order is harmless.
        wrapperclass += '''
  static GiNaC::ex #name#_func#num#_w(#args#) {
    DEBUG_M("#name#_func#num#_w()");
    python::tuple args(#num2#);
    if (!is_ex_exactly_of_type(ser,pyfunc)) {
      PyErr_SetString(PyExc_RuntimeError,"#name#_func#num#() last argument must be pyfunc object");
      throw python::error_already_set();
    }
    unsigned ser_ = GiNaC::ex_to_pyfunc(ser).get_serial();
    DEBUG_M("ser_="<<ser_);
    const std::map<unsigned, python::ref>::const_iterator viter = function_w_#name#_cache.find(ser_);
    if (viter==function_w_#name#_cache.end()) {
      PyErr_SetString(PyExc_RuntimeError,"#name#_func#num#() failed to find #name# function in cache");
      throw python::error_already_set();
    }
    #setitem#
    PyObject * ret = NULL;
    ret = PyObject_CallObject(viter->second.get(), args.get());
    if (ret==NULL)
      throw python::error_already_set();
    return ex_w(python::ref(ret));
  }
'''.replace('#name#',k).replace('#num#',num).replace('#num2#',num2).replace('#args#',a).replace('#setitem#',b)

# Close the function_w class definition.
wrapperclass += '''
};
'''
                    234: 
                    235: protos = '''
                    236: '''
                    237: 
                    238: builder = '''
                    239: python::class_builder<function_w> function_w_class(this_module, "_function_w");
                    240: python::class_builder<GiNaC::function, function_w> function_class(this_module, "_function");
                    241: function_class.declare_base(function_w_class);
                    242: function_class.declare_base(basic_class);
                    243: function_py_class = python::as_object(function_class.get_extension_class());
                    244: '''
                    245: 
                    246: constructors = '''
                    247: function_class.def(python::constructor<>());
                    248: function_class.def(python::constructor<const GiNaC::function &>());
                    249: function_class.def(python::constructor<const GiNaC::ex &>());
                    250: '''
1.2       pearu     251: for i in range(nof_function_params+1):
                    252:     a = ', '.join(i*['python::ref']+['unsigned'])
1.1       pearu     253:     constructors += '''
                    254: function_class.def(python::constructor<%s>());'''%(a)
                    255: 
                    256: defs = '''
1.2       pearu     257: function_class.def(&basic_w::python_str, "__str__");
                    258: function_class.def(&basic_w::python_repr, "__repr__");
                    259: 
1.1       pearu     260: function_class.def(&function_w::hold_w, "hold");
1.2       pearu     261: function_class.def(&function_w::nops_w, "nops");
                    262: function_class.def(&function_w::op_w, "op");
1.1       pearu     263: 
                    264: this_module.def_raw(&function_w::build_function, "_build_function");
                    265: 
1.4     ! pearu     266: 
1.1       pearu     267: function_class.def(&basic_w::coerce, "__coerce__");
1.4     ! pearu     268: BASIC_OPS(function)
        !           269: 
1.1       pearu     270: '''
                    271: 
                    272: implementation = '''
                    273: EX_TO_BASIC(function)
                    274: 
                    275: /*
                    276: static int get_nof_args(PyObject* fun) {
                    277:   PyObject * tmp = NULL;
                    278:   int res;
                    279:   if (PyCallable_Check(fun))
                    280:     if (PyObject_HasAttrString(fun,"func_code"))
                    281:       if (PyObject_HasAttrString(tmp = PyObject_GetAttrString(fun,"func_code"),"co_argcount")) {
                    282:         res = PyInt_AsLong(PyObject_GetAttrString(tmp,"co_argcount"));
                    283:         Py_XDECREF(tmp);
                    284:         return res;
                    285:       }
                    286:   return -1;
                    287: }
                    288: */
                    289: 
                    290: unsigned function_w::build_function(PyObject * args, PyObject * kws) {
                    291:   DEBUG_C("function_w::build_function(string)");
                    292:   unsigned ser;
                    293:   unsigned ser1 = registered_functions().size();
                    294:   DEBUG_C("ser1="<<ser1);
                    295:   PyObject * name_py = Py_None;
                    296:   PyObject * texname_py = Py_None;
                    297:   PyObject * nofargs_py = Py_None;
                    298:   PyObject * eval_py = Py_None;
                    299:   PyObject * evalf_py = Py_None;
                    300:   PyObject * derivative_py = Py_None;
                    301:   PyObject * series_py = Py_None;
                    302:   static char *kwlist[] = {"nofargs","name","texname","eval","evalf","derivative","series",NULL};
1.2       pearu     303:   if (!PyArg_ParseTupleAndKeywords(args,kws,"OO|OOOOO:_ginac._build_function",kwlist,
1.1       pearu     304:     &nofargs_py, &name_py, &texname_py, &eval_py, &evalf_py, &derivative_py, &series_py))
                    305:     throw python::error_already_set();
                    306:   if (!PyInt_Check(nofargs_py)) {
                    307:     PyErr_SetString(PyExc_TypeError,"_build_function() first argument must be an int");
                    308:     throw python::error_already_set();
                    309:   }
                    310:   if (!PyString_Check(name_py)) {
                    311:     PyErr_SetString(PyExc_TypeError,"_build_function() second argument must be a string");
                    312:     throw python::error_already_set();
                    313:   }
                    314:   GiNaC::function_options opt;
1.4     ! pearu     315:   opt = opt.overloaded(-1);
1.1       pearu     316:   if (texname_py == Py_None)
                    317:     opt = opt.set_name(PyString_AsString(name_py));
                    318:   else {
                    319:     if (!PyString_Check(texname_py)) {
1.2       pearu     320:       PyErr_SetString(PyExc_TypeError,"_build_function() `texname\' keyword argument must be a string");
1.1       pearu     321:       throw python::error_already_set();
                    322:     }
                    323:     opt = opt.set_name(PyString_AsString(name_py),PyString_AsString(texname_py));
                    324:   }
                    325:   int nofargs = PyInt_AsLong(nofargs_py);
1.2       pearu     326:   if (nofargs>%s || nofargs<0) {
                    327:     PyErr_SetString(PyExc_NotImplementedError,"_build_function() first argument must be less than %s");
1.1       pearu     328:     throw python::error_already_set();
                    329:   }
                    330:   opt.test_and_set_nparams(nofargs+1);
1.2       pearu     331: '''%(nof_function_params,nof_function_params)
1.1       pearu     332: 
                    333: for k in ['eval','evalf']:
                    334:     implementation += '''
                    335:   if (!(%s_py == Py_None)) {
1.2       pearu     336:     if (!PyCallable_Check(%s_py)) {
                    337:       PyErr_SetString(PyExc_TypeError,"_build_function() \'%s\' keyword argument must be callable");
                    338:       throw python::error_already_set(); 
                    339:     }
                    340:     switch (nofargs) {\n'''%(k,k,k)
                    341:     for i in range(nof_function_params+1):
1.1       pearu     342:         implementation += '      case %s: opt = opt.%s_func(%s_func%s_w); break;\n'%(i,k,k,i)
                    343:     implementation += '''
                    344:     default:
1.2       pearu     345:       PyErr_SetString(PyExc_RuntimeError,"_build_function() failure (eval/evalf nofargs bug)");
1.1       pearu     346:       throw python::error_already_set(); 
                    347:     }
1.2       pearu     348:     Py_INCREF(%s_py);
                    349:     function_w_%s_cache[ser1] = python::ref(%s_py);
1.1       pearu     350:   }
1.2       pearu     351: '''%(k,k,k)
1.1       pearu     352: 
                    353: if 1:
                    354:     implementation += '''
1.2       pearu     355:   if (!(derivative_py == Py_None)) {
                    356:     if (!PyCallable_Check(derivative_py)) {
                    357:       PyErr_SetString(PyExc_TypeError,"_build_function() \'derivative\' keyword argument must be callable");
                    358:       throw python::error_already_set(); 
                    359:     }
                    360:     switch (nofargs) {
                    361: '''
                    362:     for i in range(nof_function_params+1):
                    363:         implementation += '      case %s: opt = opt.derivative_func(derivative_func%s_w); break;\n'%(i,i)
1.1       pearu     364:     implementation += '''
                    365:     default:
1.2       pearu     366:       PyErr_SetString(PyExc_RuntimeError,"_build_function() failure (derivative nofargs bug)");
1.1       pearu     367:       throw python::error_already_set(); 
                    368:     }
1.2       pearu     369:     Py_INCREF(derivative_py);
                    370:     function_w_derivative_cache[ser1] = python::ref(derivative_py);
1.1       pearu     371:   }
                    372: '''
                    373: 
                    374: if 1:
                    375:     implementation += '''
1.2       pearu     376:   if (!(series_py == Py_None)) {
                    377:     if (!PyCallable_Check(series_py)) {
                    378:       PyErr_SetString(PyExc_TypeError,"_build_function() \'series\' keyword argument must be callable");
                    379:       throw python::error_already_set(); 
                    380:     }
                    381:     switch (nofargs) {\n'''
                    382:     for i in range(nof_function_params+1):
                    383:         implementation += '      case %s: opt = opt.series_func(series_func%s_w); break;\n'%(i,i)
1.1       pearu     384:     implementation += '''
                    385:     default:
1.2       pearu     386:       PyErr_SetString(PyExc_RuntimeError,"_build_function() failure (series nofargs bug)");
1.1       pearu     387:       throw python::error_already_set(); 
                    388:     }
1.2       pearu     389:     Py_INCREF(series_py);
                    390:     function_w_series_cache[ser1] = python::ref(series_py);
1.1       pearu     391:   }
                    392: '''
                    393: 
1.2       pearu     394: implementation += '''  ser = register_new(opt);
                    395:   if (ser != ser1) {
                    396:     PyErr_SetString(PyExc_RuntimeError, "_build_function() mismatch of function serial numbers");
1.1       pearu     397:     throw python::error_already_set();
                    398:   }
                    399:   function_w_ser_cache[ser]=ser;
                    400:   return ser;
                    401: }
                    402: 
                    403: '''
1.2       pearu     404: 

FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>