# File:  [CENS] / python / pyGiNaC / lib4 /
# Revision 1.19 (CVS web view: download - view: text, annotated - select for diffs - revision graph)
# Sat Dec 22 14:07:23 2001 UTC (15 years, 11 months ago) by pearu
# Branches: MAIN
# CVS tags: HEAD
# *** empty log message ***

#!/usr/bin/env python
"""Symbolic manipulations in Python (with GiNaC support).

  This file is part of the PyGiNaC package.

  $Revision: 1.19 $
  $Id:,v 1.19 2001-12-22 14:07:23 pearu Exp $
  Copyright 2001 Pearu Peterson all rights reserved,
  Pearu Peterson <>
  Permission to use, modify, and distribute this software is given under the
  terms of the LGPL (GNU Lesser General Public License).
"""

__author__ = "Pearu Peterson <>"
__license__ = "LGPL (see"
from __version__ import __version__

__all__ = ['ex']

import types
import re
import sys

# When pydoc merely inspects this module the compiled _ginac extension may
# be unavailable; tolerate that so documentation can still be generated.
# _ginac_ok records whether the extension was imported successfully.
# sys.argv may be missing or empty in embedded interpreters.
_argv0 = (getattr(sys, 'argv', None) or [''])[0]
if re.match(r'.*pydoc', _argv0):
    _ginac_ok = 0
    try:
        import _ginac
        _ginac_ok = 1
    except ImportError:
        pass
else:
    import _ginac
    _ginac_ok = 1

def _to_repl_lst(repl_list, errmsg):
    """Return a GiNaC lst for a replacement-list argument. Internal.

    Shared by ex.to_rational(), ex.match() and ex.find(). A lst is
    returned unchanged, a Python list is converted to a new lst, and None
    yields an empty lst. Any other value raises TypeError(errmsg).
    """
    if isinstance(repl_list, lst):
        return repl_list
    if type(repl_list) is types.ListType:
        return lst(*repl_list)
    if repl_list is None:
        return lst()
    raise TypeError(errmsg)


def _store_repl_list(repl_list, repl_lst_tmp):
    """Return the replacement list to hand back to the caller. Internal.

    If the caller passed a lst, that same object was handed to GiNaC and
    is returned as is (presumably updated in place by the _ginac call).
    Otherwise the items of repl_lst_tmp are copied into the caller's
    Python list, in place, or into a new list when repl_list is None.
    """
    if isinstance(repl_list, lst):
        return repl_list
    if repl_list is None:
        repl_list = []
    del repl_list[:]
    repl_list.extend(repl_lst_tmp.to_list())
    return repl_list


class ex:
    """Base class for expressions.

    Every instance wraps a raw _ginac.ex object in its `ex` attribute.
    Wrapping a raw object re-classes the instance to the registered
    subclass that matches its GiNaC class (see _fix_class), so, for
    example, ex(2) is a numeric instance.
    """

    # Maps GiNaC class names (as returned by get_class_name()) to the Python
    # wrapper classes; each subclass registers itself after its definition.
    _derived_classes = {} # holds classes derived from ex

    def __init__(self, arg=None):
        """Construct ex instance.

        ex()                          -> numeric(0)
        ex(ex())                      -> ex()  (same wrapped object and class)
        ex(int|float|complex|long)    -> numeric
        ex(list|tuple)                -> matrix
        ex(raw _ginac.ex)             -> instance of the matching subclass

        Raises TypeError for unsupported arguments and NotImplementedError
        for strings (no parser available yet).
        """
        t = type(arg)
        if t is types.InstanceType:
            if not isinstance(arg, ex):
                raise TypeError('ex(instance): ' + arg.__class__.__name__)
            self.ex = arg.ex
            self.__class__ = arg.__class__
        elif t is _ginac.ex:
            self.ex = arg
            # Results of GiNaC operations arrive here; give them the Python
            # class matching their GiNaC class so operators/methods work.
            self._fix_class()
        elif t in [types.IntType, types.FloatType, types.ComplexType]:
            self.ex = _ginac.numeric(arg)
            self.__class__ = numeric
        elif t is types.LongType:
            # Go through str() so arbitrarily large longs are kept exact.
            self.ex = _ginac.numeric(str(arg))
            self.__class__ = numeric
        elif t in [types.ListType, types.TupleType]:
            self.ex = matrix(arg).ex
            self.__class__ = matrix
        elif t is types.StringType:
            raise NotImplementedError('ex(str), need parser')
        elif arg is None:
            self.ex = _ginac.ex()
            self.__class__ = numeric
        else:
            raise TypeError('ex(): ' + repr(t))

    def _fix_class(self):
        """Choose class corresponding to GiNaC class. Internal.

        If no wrapper class is registered for the GiNaC class name, a
        warning is issued and the current class is kept.
        """
        n = self.ex.get_class_name()
        try:
            self.__class__ = self._derived_classes[n]
        except KeyError:
            import warnings
            warnings.warn('Not implemented: class ' + n)

    def _forget(self):
        """Forget calculated attributes. Internal.

        The base class caches nothing; subclasses that cache derived
        values override this.
        """

    def __str__(self):
        return self.ex.to_context_str('python', 0)

    def __repr__(self):
        return self.ex.to_context_str('python_repr', 0)

    def to_context(self, context='python', file=None, level=0):
        """Print ex instance to file using GiNaC context.

        If file is None, the result is returned as a string; otherwise it
        is written to file and None is returned.
        Available contexts:
          python                 (python-pretty-print)
          python_repr            (python-parsable)
          context                (ginsh-parsable)
          latex                  (latex-parsable)
          csrc|csrc_float|csrc_double|csrc_cl_N (C-source)
          tree                   (for debugging)
        level is applicable only for context tree.
        """
        if file is None:
            return self.ex.to_context_str(context, level)
        self.ex.to_context_file(file, context, level)

    def extract_archived(self, level=-1):
        """Return dictionary representation of an archived ex instance.

        The values of returned dictionary are lists of
          integer     - representing boolean
          long        - representing unsigned
          string      - representing string
          dictionary  - representing member ex instance.
          ex object   - member raw ex instance (only if level>=0)
        or a single item of the above.

        level determines the depth of the representation;
        level<0 means unlimited depth.
        """
        return self.ex.extract_archived(level)

    def get_precedence(self):
        """Return relative operator precedence (for parenthesizing output)."""
        return self.ex.get_precedence()

    def get_class_name(self):
        """Return GiNaC class name."""
        return self.ex.get_class_name()

    def get_hash(self):
        """Return hash value."""
        return self.ex.get_hash()

    def nops(self):
        """Number of operands/members."""
        return self.ex.nops()

    def op(self, i):
        """Return operand/member at position i."""
        return ex(self.ex.op(i))

    def to_list(self):
        """Return list of operands."""
        return [ex(self.ex.op(i)) for i in range(self.ex.nops())]

    def swap(self, other):
        """Efficiently swap the contents of two expressions.

        The wrapped GiNaC objects and the Python classes are exchanged and
        cached attributes of both are dropped. Raises TypeError if other is
        not an ex instance.
        """
        if not isinstance(other, ex):
            raise TypeError('swap() argument must be ex instance.')
        # Drop caches while each object still has its own class, so the
        # matching _forget() implementation runs.
        self._forget()
        other._forget()
        self.ex, other.ex = other.ex, self.ex
        self.__class__, other.__class__ = other.__class__, self.__class__

    def eval(self, level=0):
        """Perform automatic non-interruptive term rewriting rules."""
        return ex(self.ex.eval(level))

    def evalf(self, level=0):
        """Evaluate object numerically."""
        return ex(self.ex.evalf(level))

    def evalm(self):
        """Evaluate sums, products and integer powers of matrices."""
        return ex(self.ex.evalm())

    def has(self, pattern):
        """Test for occurrence of a pattern."""
        return self.ex.has(ex(pattern).ex)

    def to_rational(self, repl_list = None):
        """Rationalization of non-rational functions.

        This function converts a general expression to a rational
        polynomial by replacing all non-rational subexpressions (like
        non-rational numbers, non-integer powers or functions like
        sin(), cos() etc.) to temporary symbols. This makes it
        possible to use functions like gcd() and divide() on
        non-rational functions by applying to_rational() on the
        arguments, calling the desired function and re-substituting
        the temporary symbols in the result. To make the last step
        possible, all temporary symbols and their associated
        expressions are collected in the list specified by the
        repl_list parameter in the form [symbol == expression], ready
        to be passed as an argument to subs().

        repl_list may be a lst, a Python list (filled in place) or None.
        Returns (result, repl_list); raises TypeError for other types.
        """
        repl_lst_tmp = _to_repl_lst(repl_list,
                                    'to_rational() argument must be list|lst|None')
        result = ex(self.ex.to_rational(repl_lst_tmp.ex))
        return result, _store_repl_list(repl_list, repl_lst_tmp)

    def match(self,pattern,repl_list = None):
        """Check whether the expression matches a given pattern.

        Return (b,repl_list) where b is non-zero if the expression
        matches, and repl_list is a list of found matches (that can be
        given also as a second argument: a lst, or a Python list that is
        filled in place). Raises TypeError for other repl_list types.
        """
        repl_lst_tmp = _to_repl_lst(repl_list,
                                    'match() 2nd argument must be list|lst|None')
        b = self.ex.match(ex(pattern).ex, repl_lst_tmp.ex)
        return b, _store_repl_list(repl_list, repl_lst_tmp)

    def find(self,pattern,repl_list = None):
        """Find all occurrences of a pattern.

        Return (b,repl_list) where b is non-zero if any matches were
        found, and repl_list is a list of found matches (that can be
        given also as a second argument: a lst, or a Python list that is
        filled in place). Raises TypeError for other repl_list types.
        """
        repl_lst_tmp = _to_repl_lst(repl_list,
                                    'find() 2nd argument must be list|lst|None')
        b = self.ex.find(ex(pattern).ex, repl_lst_tmp.ex)
        return b, _store_repl_list(repl_list, repl_lst_tmp)

    def subs(self, *rels):
        """Substitute a set of objects by arbitrary expressions.

        Arguments must be object==expression relations, or a single
        list/lst of such relational instances.
        """
        if len(rels) == 1:
            # rels is a tuple, so convert into a separate name.
            rel = rels[0]
            if type(rel) is types.ListType:
                rel = lst(*rel)
            if isinstance(rel, relational) or isinstance(rel, lst):
                return ex(self.ex.subs(rel.ex))
        return ex(self.ex.subs(lst(*rels).ex))

    def is_function(self,name=None,nofargs=None):
        """Test for function with a name and nof arguments.

        The base class is never a function and returns 0.
        """
        return 0

    def __nonzero__(self):
        """Test for non-zero."""
        return not self.ex.is_zero()

    def is_equal(self, other):
        """Test for equality."""
        return self.ex.is_equal(ex(other).ex)

    def is_zero(self):
        """Test for zero."""
        return self.ex.is_zero()

    def expand(self):
        """Expand expression (multiply out)."""
        return ex(self.ex.expand(0))

    def collect(self,*objs,**kws):
        """Sort expanded expression in terms of powers of some objects.

        The objects are given as separate arguments or as a single lst.
        The following keywords are recognized:
          recursive=1       (default)
          distributed=0     (default: not recursive)
        that, having opposite effects, indicate the corresponding form
        of collect. distributed takes precedence when both are given.
        """
        flag = kws.get('distributed', not kws.get('recursive', 1))
        if len(objs) == 1 and isinstance(objs[0], lst):
            return ex(self.ex.collect(objs[0].ex, flag))
        return ex(self.ex.collect(lst(*objs).ex, flag))

    def normal(self,level=0):
        """Normalization of rational functions.

        Converts an expression to its normal form
        `numerator/denominator', where numerator and denominator are
        (relatively prime) polynomials.  Any subexpressions which are
        not rational functions (like non-rational numbers, non-integer
        powers, or functions like sin(), cos() etc.)  are replaced by
        temporary symbols which are re-substituted by the (normalized)
        subexpressions before normal() returns (this way, any
        expression can be treated as a rational function).  normal()
        is applied recursively to arguments of functions etc.  level
        is maximum depth of recursion.
        """
        return ex(self.ex.normal(level))

    def numer(self):
        """Get numerator of an expression."""
        return ex(self.ex.numer())

    def denom(self):
        """Get denominator of an expression."""
        return ex(self.ex.denom())

    def numer_denom(self):
        """Get (numerator, denominator) of an expression as a tuple."""
        return tuple(map(ex, self.ex.numer_denom().lst_to_list()))

class functional_mths:
    """Mixin adding calculus methods (diff, series) and abs() support."""

    def diff(self, s, nth=1):
        """Compute partial derivative of an expression.

        s must be a symbol and nth (the derivative order) an int >= 0;
        otherwise TypeError is raised.
        """
        order_is_valid = isinstance(nth, types.IntType) and nth >= 0
        if not order_is_valid:
            raise TypeError('diff() method 2nd argument must be positive int')
        if not isinstance(s, symbol):
            raise TypeError('diff() method 1st argument must be symbol')
        return ex(self.ex.diff(s.ex, nth))

    def series(self, r, order, options=0):
        """Compute the truncated series expansion of an expression.

        r is passed on as r.ex, so it must be an ex instance.
        """
        raw = self.ex.series(r.ex, order, options)
        return ex(raw)

    def __abs__(self):
        # Symbolic absolute value; symbolic_abs is defined elsewhere in
        # this module.
        return symbolic_abs(self)

    #TODO: expose sin,cos,.. as methods

class degree_mths:
    """Mixin with degree and coefficient queries on polynomials.

    The object s is converted with ex() before it is passed to GiNaC.
    """

    def degree(self, s):
        """Return degree of highest power in object s."""
        return self.ex.degree(ex(s).ex)

    def ldegree(self, s):
        """Return degree of lowest power in object s."""
        return self.ex.ldegree(ex(s).ex)

    def coeff(self, s, n=1):
        """Return coefficient of degree n in object s.

        Raises TypeError if n is not an int.
        """
        if not isinstance(n, types.IntType):
            raise TypeError('coeff() method 2nd argument must be int but got %r'
                            % (type(n),))
        return ex(self.ex.coeff(ex(s).ex, n))

    def lcoeff(self, s):
        """Return leading coefficient in object s."""
        return ex(self.ex.lcoeff(ex(s).ex))

    def tcoeff(self, s):
        """Return trailing coefficient in object s."""
        return ex(self.ex.tcoeff(ex(s).ex))

class polynomial_mths(degree_mths,functional_mths):
    """Mixin with multivariate polynomial methods (unit, content, primpart, ...)."""

    def unit(self,x):
        """Compute unit part (= sign of leading coefficient) of a
        multivariate polynomial in Z[x].

        Raises TypeError if x is not a symbol.
        """
        if isinstance(x, symbol):
            return ex(self.ex.unit(x.ex))
        raise TypeError('unit() method argument must be symbol')

    def content(self,x):
        """Compute content part (= unit normal GCD of all
        coefficients) of a multivariate polynomial in Z[x].

        Raises TypeError if x is not a symbol.
        """
        if isinstance(x, symbol):
            return ex(self.ex.content(x.ex))
        raise TypeError('content() method argument must be symbol')

    def integer_content(self):
        """Compute the integer content (= GCD of all numeric
        coefficients) of an expanded polynomial."""
        return ex(self.ex.integer_content())

    def primpart(self,x):
        """Compute primitive part of a multivariate polynomial in Z[x].

        Raises TypeError if x is not a symbol.
        """
        if isinstance(x, symbol):
            return ex(self.ex.primpart(x.ex))
        raise TypeError('primpart() method argument must be symbol')

    def smod(self,xi):
        """Apply symmetric modular homomorphism to a multivariate
        polynomial.

        xi is converted with numeric(); raises ZeroDivisionError if it
        is zero.
        """
        xi = numeric(xi)
        if xi.ex.is_zero():
            raise ZeroDivisionError('division by zero')
        return ex(self.ex.smod(xi.ex))

    def max_coefficient(self):
        """Return maximum (absolute value) coefficient of a
        polynomial."""
        return ex(self.ex.max_coefficient())

def _canonic_result(op1,op2,result):
    """Adapt result to the first pseries operand, if there is one.

    op1 is inspected before op2; the first operand that is a pseries
    converts result through its get_compatible() method. When neither
    operand is a pseries, result is returned unchanged.
    """
    for operand in (op1, op2):
        if isinstance(operand, pseries):
            return operand.get_compatible(result)
    return result

def _operand_name(obj):
    """Return a class name for obj to use in operator error messages. Internal.

    ex instances report their GiNaC class name. Other objects (e.g.
    plain numbers) have no get_class_name(), so fall back to the Python
    class name instead of failing with AttributeError.
    """
    try:
        return obj.get_class_name()
    except AttributeError:
        return getattr(obj, '__class__', type(obj)).__name__


class arith_ops:
    """Mixin providing the arithmetic operators for expressions.

    Operands are converted to ex by __coerce__ (Python 2 classic-class
    coercion). Results involving a pseries operand are adapted by
    _canonic_result().
    """

    def _check (self,other,op):
        """Raise TypeError unless `self op other` is supported."""
        if not isinstance(other, arith_ops):
            raise TypeError('%s %s %s not defined'
                            % (_operand_name(self), op, _operand_name(other)))

    def _rcheck (self,other,op):
        """Raise TypeError unless `other op self` is supported."""
        if not isinstance(other, arith_ops):
            raise TypeError('%s %s %s not defined'
                            % (_operand_name(other), op, _operand_name(self)))

    def __coerce__(self,other):
        other = ex(other)
        return self,other

    def __pos__(self):
        return self

    def __neg__(self):
        return _canonic_result(self,None,ex(-self.ex))

    def __add__(self, other):
        arith_ops._check (self,other,'+')
        return _canonic_result(self,other,ex(self.ex + other.ex))

    def __radd__(self, other):
        arith_ops._rcheck (self,other,'+')
        return _canonic_result(self,other,ex(other.ex + self.ex))

    def __sub__(self, other):
        arith_ops._check (self,other,'-')
        return _canonic_result(self,other,ex(self.ex - other.ex))

    def __rsub__(self, other):
        arith_ops._rcheck (self,other,'-')
        return _canonic_result(self,other,ex(other.ex - self.ex))

    def __mul__(self, other):
        arith_ops._check (self,other,'*')
        return _canonic_result(self,other,ex(self.ex * other.ex))

    def __rmul__(self, other):
        arith_ops._rcheck (self,other,'*')
        return _canonic_result(self,other,ex(other.ex * self.ex))

    def __div__(self, other):
        arith_ops._check (self,other,'/')
        return _canonic_result(self,other,ex(self.ex / other.ex))

    def __rdiv__(self, other):
        arith_ops._rcheck (self,other,'/')
        return _canonic_result(self,other,ex(other.ex / self.ex))

    # '/' under `from __future__ import division` (and in Python 3) looks up
    # __truediv__. Delegate dynamically so subclass overrides of __div__ apply.
    def __truediv__(self, other):
        return self.__div__(other)

    def __rtruediv__(self, other):
        return self.__rdiv__(other)

    def __pow__(self, other):
        arith_ops._check (self,other,'**')
        return _canonic_result(self,other,power(self, other))

    def __rpow__(self, other):
        arith_ops._rcheck (self,other,'**')
        return _canonic_result(self,other,power(other, self))

class relat_ops:
    """Mixin providing the comparison operators for expressions.

    Each operator converts its operand with ex() and returns
    ex(_ginac.relational(...)) rather than a Python boolean. `a > b` is
    stored as `b < a`, and `a >= b` as `b <= a`.
    """

    def _check (self,other,rel):
        if not isinstance(other, relat_ops):
            raise TypeError('%s %s %s not defined'
                            % (self.get_class_name(), rel, other.get_class_name()))

    def _relation(self, other, rel, op, swapped):
        """Build the relation for operator `rel` as GiNaC operator `op`. Internal.

        If swapped is true, the operands are exchanged.
        """
        other = ex(other)
        relat_ops._check(self, other, rel)
        if swapped:
            left, right = other, self
        else:
            left, right = self, other
        return ex(_ginac.relational(left.ex, right.ex, op))

    def __lt__(self, other):
        return relat_ops._relation(self, other, '<', '<', 0)

    def __gt__(self, other):
        return relat_ops._relation(self, other, '>', '<', 1)

    def __le__(self, other):
        return relat_ops._relation(self, other, '<=', '<=', 0)

    def __ge__(self, other):
        return relat_ops._relation(self, other, '>=', '<=', 1)

    def __eq__(self, other):
        return relat_ops._relation(self, other, '==', '==', 0)

    def __ne__(self, other):
        return relat_ops._relation(self, other, '!=', '!=', 0)

class numeric(ex,arith_ops,relat_ops,polynomial_mths):
    """CLN arbitrary precision number.

      numeric(numer, denom)
    where numer and denom are Python numbers or strings representing
    numbers. Both arguments are optional. E.g.
      numeric ('2/3')
      numeric (2, 3)
      numeric ('2-3I')
      numeric ('2.3e1000000')
    The precision of newly created numeric instances can be changed
    using Digits() function.

    Besides the generic ex interface, numeric provides predicates
    (is_integer(), is_prime(), ...), Python conversions (int(), long(),
    float(), complex()) and number-theoretic/transcendental functions.
    These all delegate to the numeric_* methods of the wrapped object.
    """
    def __init__(self, numer=0, denom=1):
        """Construct numeric instance.

        numeric()                 -> 0
        numeric(numer)            -> numer
        numeric(numer, denom)     -> numer/denom

        Strings are parsed by _ginac.numeric(). The instance is re-classed
        to match the wrapped GiNaC object, so e.g. numeric(x) for a symbol
        x yields a symbol.
        """
        if denom is 1 and isinstance(numer, numeric):
            # Plain copy of an existing numeric.
            self.ex = numer.ex
            return
        if type(numer) is types.StringType:
            numer = _ginac.numeric(numer)
        numer = ex(numer)
        if denom is 1:
            self.ex = numer.ex
        else:
            denom = numeric(denom)
            self.ex = numer.ex / denom.ex
        # The quotient is not necessarily a number (e.g. 2/x), so always
        # pick the class matching the wrapped object.
        self._fix_class()

    # --- predicates -------------------------------------------------------
    def is_zero(self):
        return self.ex.numeric_is_zero()
    def is_negative(self):
        return self.ex.numeric_is_negative()
    def is_positive(self):
        return self.ex.numeric_is_positive()
    def is_integer(self):
        return self.ex.numeric_is_integer()
    def is_pos_integer(self):
        return self.ex.numeric_is_pos_integer()
    def is_nonneg_integer(self):
        return self.ex.numeric_is_nonneg_integer()
    def is_even(self):
        return self.ex.numeric_is_even()
    def is_odd(self):
        return self.ex.numeric_is_odd()
    def is_prime(self):
        """Test for a prime number; non-positive numbers are never prime."""
        if self.ex.numeric_is_positive():
            return self.ex.numeric_is_prime()
        return 0
    def is_real(self):
        return self.ex.numeric_is_real()
    def is_rational(self):
        return self.ex.numeric_is_rational()
    def is_cinteger(self):
        return self.ex.numeric_is_cinteger()
    def is_crational(self):
        return self.ex.numeric_is_crational()

    # --- parts and conversions --------------------------------------------
    def real(self):
        return ex(self.ex.numeric_real())
    def imag(self):
        return ex(self.ex.numeric_imag())
    def numer(self):
        return ex(self.ex.numeric_numer())
    def denom(self):
        return ex(self.ex.numeric_denom())
    def inverse(self):
        """Return 1/self; raises ZeroDivisionError for zero."""
        if self.ex.is_zero():
            raise ZeroDivisionError('division by zero')
        return ex(self.ex.numeric_inverse())
    def csgn(self):
        return self.ex.numeric_csgn()
    def int_length(self):
        return self.ex.numeric_int_length()
    def __int__(self):
        # Exact conversion for integers; non-integers are truncated via float.
        if self.ex.numeric_is_integer():
            return self.ex.numeric_to_long()
        else:
            return int(self.ex.numeric_to_double())
    def __long__(self):
        return long(str(self))
    def __float__(self):
        return self.ex.numeric_to_double()
    def __complex__(self):
        r,i = self.ex.numeric_real(),self.ex.numeric_imag()
        return complex(r.numeric_to_double(), i.numeric_to_double())
    def __abs__(self):
        return ex(self.ex.numeric_abs())

    # --- operators --------------------------------------------------------
    def __pow__(self, other):
        """Return self**other; other must be numeric (NotImplementedError otherwise)."""
        if isinstance(other, numeric):
            return ex(self.ex.numeric_power(other.ex))
        raise NotImplementedError(type(other))
    def __rpow__(self, other):
        if isinstance(other, numeric):
            return ex(other.ex.numeric_power(self.ex))
        raise NotImplementedError(type(other))
    def compare(self, other):
        """Establishes canonical order of all numbers.

        Return (self-other).csgn(). (GiNaC internal??)
        Raises NotImplementedError if other is not numeric.
        """
        other = ex(other)
        if isinstance(other, numeric):
            return self.ex.numeric_compare(other.ex)
        raise NotImplementedError(type(other))
    def is_equal(self, other):
        if isinstance(other, numeric):
            return ex(self.ex.numeric_is_equal(other.ex))
        return ex.is_equal(self, other)
    def __lshift__(self, other):
        """Return self * (2**other)"""
        if isinstance(other, numeric):
            return ex(self.ex.numeric_lshift(other.ex))
        raise NotImplementedError(type(other))
    def __rshift__(self, other):
        """Return floor(self / (2**other))"""
        if isinstance(other, numeric):
            return ex(self.ex.numeric_rshift(other.ex))
        raise NotImplementedError(type(other))
    def __invert__(self):
        # Python semantics: ~x == -(x+1)
        return ex(self.ex.numeric_add(1).numeric_mul(-1))
    def conjugate(self):
        # real(self) + imag(self)*(-I); I is defined elsewhere in this module.
        return ex(self.ex.numeric_real().numeric_add(self.ex.numeric_imag().numeric_mul(-I.ex)))

    # --- functions --------------------------------------------------------
    def exp(self): return ex(self.ex.numeric_exp())
    def log(self): return ex(self.ex.numeric_log())
    def sin(self): return ex(self.ex.numeric_sin())
    def cos(self): return ex(self.ex.numeric_cos())
    def tan(self): return ex(self.ex.numeric_tan())
    def asin(self): return ex(self.ex.numeric_asin())
    def acos(self): return ex(self.ex.numeric_acos())
    def atan(self): return ex(self.ex.numeric_atan())
    def sinh(self): return ex(self.ex.numeric_sinh())
    def cosh(self): return ex(self.ex.numeric_cosh())
    def tanh(self): return ex(self.ex.numeric_tanh())
    def asinh(self): return ex(self.ex.numeric_asinh())
    def acosh(self): return ex(self.ex.numeric_acosh())
    def atanh(self): return ex(self.ex.numeric_atanh())
    def Li2(self): return ex(self.ex.numeric_Li2())
    def zeta(self): return ex(self.ex.numeric_zeta())
    def lgamma(self): return ex(self.ex.numeric_lgamma())
    def tgamma(self): return ex(self.ex.numeric_tgamma())
    def psi(self): return ex(self.ex.numeric_psi())
    def factorial(self): return ex(self.ex.numeric_factorial())
    def doublefactorial(self): return ex(self.ex.numeric_doublefactorial())
    def bernoulli(self): return ex(self.ex.numeric_bernoulli())
    def fibonacci(self): return ex(self.ex.numeric_fibonacci())
    def isqrt(self): return ex(self.ex.numeric_isqrt())
    def sqrt(self): return ex(self.ex.numeric_sqrt())
    def abs(self): return ex(self.ex.numeric_abs())
    def atan2(self, other): return ex(self.ex.numeric_atan2(ex(other).ex))
    def psi2(self, other): return ex(self.ex.numeric_psi2(ex(other).ex))

    # --- integer arithmetic (ZeroDivisionError for a zero divisor) ---------
    def mod(self, other):
        other = ex(other)
        if other.is_zero():
            raise ZeroDivisionError('division by zero')
        return ex(self.ex.numeric_mod(other.ex))
    def smod(self, other):
        other = ex(other)
        if other.is_zero():
            raise ZeroDivisionError('division by zero')
        return ex(self.ex.numeric_smod(other.ex))
    def binomial(self, other):
        return ex(self.ex.numeric_binomial(ex(other).ex))
    def irem(self, other):
        other = ex(other)
        if other.is_zero():
            raise ZeroDivisionError('division by zero')
        return ex(self.ex.numeric_irem(other.ex))
    def iquo(self, other):
        other = ex(other)
        if other.is_zero():
            raise ZeroDivisionError('division by zero')
        return ex(self.ex.numeric_iquo(other.ex))
    def gcd(self, other):
        return ex(self.ex.numeric_gcd(ex(other).ex))
    def lcm(self, other):
        return ex(self.ex.numeric_lcm(ex(other).ex))

ex._derived_classes['numeric'] = numeric

class symbol(ex,arith_ops,relat_ops,polynomial_mths):
    """Basic CAS symbol.

      symbol (name, TeX_name)
    where name and TeX_name are strings, both are optional.

    Symbols are cached by name in the class attribute _cache: creating a
    symbol with an already used name wraps the same GiNaC symbol.
    """
    _cache = {}                  # symbol name -> raw GiNaC ex
    _counter = 0                 # numbering for automatically named symbols
    _autoname_prefix = 'symbol'

    def __init__(self, name=None, TeX_name=None):
        """Construct symbol instance.

        symbol (symbol (...)) -> symbol (...)
        symbol ()             -> symbol ('symbol\\d+')
        symbol ('name','texname')
        symbol ('name')

        Raises TypeError for a non-string name or TeX_name, and ValueError
        when a TeX_name is given together with an existing symbol.
        """
        cls = self.__class__
        if name is None:
            name = '%s%d' % (cls._autoname_prefix, cls._counter)
            cls._counter += 1
        elif isinstance(name, symbol):
            self.ex = name.ex
            if TeX_name is not None:
                raise ValueError(cls.__name__+' '+repr(name)+' has a TeX_name')
            return
        elif type(name) is not types.StringType:
            raise TypeError('symbol() expects string name but got %s' % (type(name),))
        try:
            cached = cls._cache[name]
        except KeyError:
            pass
        else:
            # Existing symbol: its TeX name cannot be changed.
            if TeX_name is not None:
                raise ValueError(cls.__name__+' '+repr(name)+' has a TeX_name')
            self.ex = cached
            return
        if TeX_name is None:
            self.ex = _ginac.symbol(name)
        elif type(TeX_name) is types.StringType:
            self.ex = _ginac.symbol(name, TeX_name)
        else:
            raise TypeError('symbol() expects string TeX_name but got %s'
                            % (type(TeX_name),))
        cls._cache[name] = self.ex

    def get_name(self):
        """Get symbol name."""
        return self.ex.symbol_get_name()

    def get_TeX_name(self):
        """Get symbol TeX name."""
        return self.to_context('latex')

    def set_name(self, name):
        """Set symbol name.

        Not supported by this wrapper. It raises NotImplementedError
        rather than silently leaving the name unchanged.
        """
        raise NotImplementedError('symbol.set_name()')

ex._derived_classes['symbol'] = symbol

class constant(ex,arith_ops,relat_ops,polynomial_mths):
    """Constants, symbols with specific numerical value.

      constant (name, value, TeX_name)
      name     - name of the constant. Required.
      value    - is numerical value of the constant or a function that
                 evaluates that value. It is returned with evalf() method.
      TeX_name - LaTeX representation of the constant name. Optional.

    Constants are cached by name: constant(name) with an already used
    name wraps the existing constant.
    """
    _cache = {}   # constant name -> raw GiNaC ex

    def __init__(self, name, value=None, TeX_name = None):
        """Construct constant instance.

        constant (name)
        constant (name, number)
        constant (name, lambda : number)
        constant (name, number, texname)

        'Pi', 'Euler' and 'Catalan' refer to GiNaC's built-in constants
        and accept no value. Raises TypeError for a non-string name or
        TeX_name, and ValueError when a value (or, for cached constants,
        a TeX_name) is given where none is allowed.
        """
        cls = self.__class__
        if type(name) is not types.StringType:
            raise TypeError(cls.__name__+'() expects string name')
        try:
            cached = cls._cache[name]
        except KeyError:
            pass
        else:
            if not (value is None and TeX_name is None):
                raise ValueError(cls.__name__+' '+repr(name)+' has a value or/and TeX_name')
            self.ex = cached
            return
        if TeX_name is None:
            TeX_name = '\\mbox{%s}' % name
        elif type(TeX_name) is not types.StringType:
            raise TypeError(cls.__name__+'() expects string TeX_name but got '+repr(type(TeX_name)))
        if name in ['Pi','Euler','Catalan']:
            if value is not None:
                raise ValueError('internal constant '+repr(name)+' has a value')
            self.ex = getattr(_ginac, name)
        else:
            self.ex = _ginac.constant(name, value, TeX_name)
        cls._cache[name] = self.ex

    def _forget (self):
        """Forget the cached name (see get_name). Internal."""
        try:
            del self._name
        except AttributeError:
            pass

    def get_init(self):
        """Return constant init value without evaluation.

        Raw GiNaC expressions are wrapped in ex; other values are returned
        as they are.
        """
        r = self.ex.constant_get_init()
        if type(r) is _ginac.ex:
            return ex(r)
        return r

    def get_name(self):
        """Return constant name (cached in self._name after the first call)."""
        if hasattr(self,'_name'):
            return self._name
        d = self.ex.extract_archived(0)
        self._name = d['name']
        return self._name

    def get_TeX_name(self):
        """Return constant TeX name."""
        return self.to_context('latex')

ex._derived_classes['constant'] = constant

class add(ex,arith_ops,relat_ops,polynomial_mths):
    """Sum of expressions.

      add (expr1, expr2, ...)
    """
    def __init__(self, *args):
        """Construct the sum of args (each converted with ex()).

        add ()           -> 0
        add (expr)       -> expr
        add (e1,e2,...)  -> e1+e2+...

        The instance is re-classed to match the GiNaC result, so that the
        results above carry the corresponding Python class.
        """
        self.ex = _ginac.add([ex(arg) for arg in args])
        self._fix_class()

ex._derived_classes['add'] = add

class mul(ex,arith_ops,relat_ops,polynomial_mths):
    """Product of expressions.

      mul (expr1, expr2, ...)
    """
    def __init__(self, *args):
        """Construct the product of args (each converted with ex()).

        mul ()           -> 1
        mul (expr)       -> expr
        mul (e1,e2,...)  -> e1*e2*...

        The instance is re-classed to match the GiNaC result, so that the
        results above carry the corresponding Python class.
        """
        self.ex = _ginac.mul([ex(arg) for arg in args])
        self._fix_class()

    def __abs__ (self): # FIXME: this should be implemented in GiNaC, if at all
        """Return abs(self), pulling simplifiable factors out of abs().

        Factors whose abs() stays a symbolic abs function are kept
        together under a single symbolic_abs(). Factors whose abs()
        simplifies (e.g. numbers) are multiplied outside it.
        """
        terms1 = []   # factors t whose abs(t) stays symbolic (kept as t)
        terms2 = []   # simplified abs(t) of the remaining factors
        for t in self.to_list():
            t1 = abs(t)
            if t1.is_function('abs',1):
                terms1.append(t)
            else:
                terms2.append(t1)
        return symbolic_abs(mul(*terms1))*mul(*terms2)

ex._derived_classes['mul'] = mul

class ncmul(ex,arith_ops,relat_ops,functional_mths):
    """Non-commutative product of expressions.

      ncmul(expr1, expr2, ...)
    """
    def __init__(self, *args):
        """Construct the non-commutative product of args (converted with ex()).

        The instance is re-classed to match the GiNaC result.
        """
        self.ex = _ginac.ncmul([ex(arg) for arg in args])
        self._fix_class()

ex._derived_classes['ncmul'] = ncmul

class power(ex,arith_ops,relat_ops,polynomial_mths):
    """Power of expressions.

      power (base, exponent)
    """
    def __init__(self, base, exponent):
        """Construct base**exponent; both are converted with ex().

        The instance is re-classed to match the GiNaC result, which may
        have been simplified (e.g. to the base itself).
        """
        self.ex = _ginac.power(ex(base).ex, ex(exponent).ex)
        self._fix_class()

def sqrt(e):
    """Square root.

    Returns power(e, _ex1_2). _ex1_2 is defined elsewhere in this module
    (presumably the exponent 1/2).
    """
    return power(e,_ex1_2)

ex._derived_classes['power'] = power

class relational(ex):
    """This class holds a relation consisting of two expressions and
    a logical relation between them.

      relational (lhs,rhs,oper)
      lhs == rhs
      lhs < rhs
    where oper can be `==', `!=', `<', `<=', `>', or `>='.
    """
    def __init__(self, rhs, lhs=None, oper='=='):
        """Construct relational instance.

        relational (relational (...)) -> relational (...)
        relational (lhs, rhs) -> relational (lhs, rhs, '==')
        relational (lhs, rhs, oper)

        Note: the parameter names are swapped relative to their meaning.
        The first positional argument (named rhs) becomes the left hand
        side. The names are kept for compatibility with keyword callers.
        An omitted second operand is converted with ex(None), i.e. 0.
        Raises ValueError when copying a relational with extra arguments.
        """
        if isinstance(rhs, relational):
            if lhs is not None or oper != '==':
                raise ValueError('relational(relational) expects lhs is None and oper="==" but got (%r,%r)'%(lhs,oper))
            self.ex = rhs.ex
            return
        self.ex = _ginac.relational(ex(rhs).ex, ex(lhs).ex, oper)

    def __nonzero__(self):
        """Return 1 if relation is true.

        Return 0 if relation is either false or undecidable.
        Note that (a<b)==0 does not imply that (a>=b)==1 in the
        general symbolic sense.
        """
        return self.ex.relational_bool()

    def operator(self):
        """Return string of the logical relation symbol"""
        d = self.ex.extract_archived(0)
        return {0:'==',1:'!=',2:'<',3:'<=',4:'>',5:'>='}[d['op']]

    def lhs(self):
        """Return left hand side of the logical relation"""
        return ex(self.ex.op(0))

    def rhs(self):
        """Return right hand side of the logical relation"""
        return ex(self.ex.op(1))

ex._derived_classes['relational'] = relational

class lst(ex):
    """GiNaC list of expressions.

    NOTE(review): this copy of the file is missing the implementations of
    __setitem__, append, prepend, sort, unique, remove_last and
    remove_first (only their headers/docstrings survived).  As written
    they do nothing and return None -- restore them from upstream.
    """
    def __init__(self,*items):
        """Construct a lst from the given items.

        lst()               -> []
        lst(expr1,expr2,..) -> [expr1,expr2,...]

        Every item is converted with ex() first.
        """
        self.ex = _ginac.lst(map(ex,items))
    def to_list(self):
        """Return the items as a Python list of ex instances."""
        return map(ex,self.ex.lst_to_list())
    def __len__(self):
        """Return the number of items (nops() of the underlying expression)."""
        return self.ex.nops()
    def __getitem__(self, index):
        """Get item/sublist of a lst.

        l[i]  - i-th item
        l[slice] - sublist where slice is start:end:step with
        start, end, and step being optional integers (also negative).
        """
        return ex(self.ex.get_slice(index))
    def __setitem__(self, index, other):
        """Set item/sublist of a lst.

        l[i] = e        - set i-th item to e.
        l[slice] = e    - if e.nops()==0, all items in l[slice]
        are set to e. Otherwise e.nops()==l[slice].nops() must hold.
        """
        # NOTE(review): body missing from this copy; currently a no-op.
    def __nonzero__(self):
        """True iff the list has at least one item."""
        return self.ex.nops() != 0
    def append(self,other):
        """Append other to end."""
        # NOTE(review): body missing from this copy; currently a no-op.
    def prepend(self,other):
        """Prepend other to begin."""
        # NOTE(review): body missing from this copy; currently a no-op.
    def sort(self):
        """Sort in place (using compare)."""
        # NOTE(review): body missing from this copy; currently a no-op.
    def unique(self):
        """Remove repeated subsequent occurrences."""
        # NOTE(review): body missing from this copy; currently a no-op.
    def remove_last(self):
        """Presumably removes the last item.

        NOTE(review): implementation missing from this copy; as written
        this method does nothing.
        """
    def remove_first(self):
        """Presumably removes the first item.

        NOTE(review): implementation missing from this copy; as written
        this method does nothing.
        """

ex._derived_classes['lst'] = lst

def _sequence_check (obj):
    """Check for sequence."""
        return 1
    except (TypeError,AttributeError):
        return 0

class matrix(ex,arith_ops):
    """Symbolic matrix (elements are stored row-wise).

    NOTE(review): several lines of this class are missing from this copy
    of the file (see the NOTE(review) comments below); it does not
    compile as is -- restore the missing statements from upstream.
    """
    def __init__(self,m):
        """Construct matrix instance.

        matrix (matrix (...)) -> matrix (...)
        matrix (seq)

        seq is a sequence of rows; rows shorter than the column count are
        padded with 0.  Raises TypeError if m is neither a matrix nor a
        sequence, ValueError if the result would have no rows or columns.
        """
        #TODO: Implement __init__ in C++
        if isinstance(m,matrix):
            self.ex = m.ex
            # NOTE(review): a `return` appears to be missing here; without it
            # a matrix argument falls through to the sequence handling below.
        if not _sequence_check(m):
            raise TypeError,'matrix() argument must be sequence|matrix'
        l = []      # rows collected for flattening
        cols = 0
        rows = 0
        for r in m:
            rows += 1
            if _sequence_check(r):
                # NOTE(review): the body of this `if` (and presumably an
                # `else:` branch for non-sequence rows) is missing from this
                # copy; presumably it appends the row to `l` and updates
                # `cols`, which the loop below relies on.
        flat_list = []
        for r in l:
            # pad every row with zeros up to `cols` entries (row-major order)
            flat_list += list(r)+(cols-len(r))*[0]
        if not (rows and cols):
            # NOTE(review): the format arguments are swapped (cols is
            # reported as rows and vice versa).
            raise ValueError,'matrix(): rows=%s and cols=%s must be positive'%(cols,rows)
        self.ex = _ginac.matrix(rows,cols,flat_list)
    def __getitem__(self,index):
        """Get element/submatrix of a matrix.

        m[i,j]  - element in i-th row and j-th column
        m[j]    - entire j-th column
        m[i,:]  - entire i-th row
        m[slice] - slice of columns
        m[slice,slice] - submatrix
        where slice is in the form start:end:step, with start,end,step
        being all optional (can be negative), including colons between them.
        """
        return ex(self.ex.matrix_get_slice(index))
    def __setitem__ (self,index,other):
        """Set element/submatrix of a matrix.

        m[i,j] = e  - set element m[i,j] to e.
        m[slice,slice] = e    - if e.nops()==0, then all elements
        in m[slice,slice] are set to e. Otherwise, e.nops()==m[slice,slice].nops()
        must hold.
        """
        # NOTE(review): body missing from this copy; currently a no-op.
    def get_rows(self):
        """Get number of rows."""
        return self.ex.matrix_rows()
    def get_cols(self):
        """Get number of columns."""
        return self.ex.matrix_cols()
    def trace (self):
        """Trace of a matrix."""
        return ex(self.ex.matrix_trace())
    def inverse(self):
        """Inverse matrix."""
        return ex(self.ex.matrix_inverse())
    def transpose(self):
        """Transposed matrix."""
        return ex(self.ex.matrix_transpose())
    def charpoly(self,s):
        """Characteristic Polynomial in the symbol s (TypeError otherwise)."""
        if isinstance (s,symbol):
            return ex(self.ex.matrix_charpoly(s.ex))
        raise TypeError,'charpoly() method argument must be symbol'        
    def determinant(self,algo='automatic'):
        """Determinant of square matrix.

        algo is a switch to control algorithm for determinant computation.
        It must be string automatic|gauss|divfree|laplace|bareiss.
        """
        return ex(self.ex.matrix_determinant(algo))
    def solve(self,rhs,vars=None,algo='automatic'):
        """Solve a linear system consisting of a m x n matrix
        and a m x p right hand side by applying an elimination
        scheme to the augmented matrix.

        vars is None or a n x p matrix of symbols (and only symbols).
        algo is a switch to control algorithm for linear system solving.
        It must be string automatic|gauss|divfree|bareiss.
        """
        rhs = matrix(rhs)
        if vars is None:
            # NOTE(review): lines are missing from this copy here: `vars`
            # is never initialised, the inner loop has no body and `row` is
            # never stored.  Presumably this builds an n x p matrix of fresh
            # symbols -- restore from upstream.
            for r in range(self.ex.matrix_cols()):
                row = []
                for c in range(rhs.ex.matrix_cols()):
        vars = matrix (vars)
        return ex(self.ex.matrix_solve(vars.ex,rhs.ex,algo))

ex._derived_classes['matrix'] = matrix

def diag_matrix(dmap):
    """Construct matrix from its diagonals.

    dmap is a dictionary with items (i,l), where
    i is integer referring to i-th diagonal, counted from the main
    diagonal. For the lower diagonals, key i is negative.
    l is a sequence of symbolic expressions in the corresponding diagonal;
    a non-sequence l is treated as the one-element sequence [l].

    The matrix is just large enough to hold every given diagonal; all
    other elements are 0.  dmap itself is not modified.

    Raises TypeError for non-int keys and ValueError if the resulting
    matrix would have no rows or columns.
    """
    #TODO: implement diagmatrix in C++.
    diagonals = {}      # normalized copy of dmap (values are sequences)
    rows = 0
    cols = 0
    for k,l in dmap.items():
        if type (k) is not types.IntType:
            raise TypeError('diag_matrix() dict-argument key must be int')
        if not _sequence_check(l):
            l = [l]
        diagonals[k] = l
        length = len(l)
        if k>0:
            # upper diagonal k: its n-th element sits at (n, n+k)
            cols = max (cols,k+length)
            rows = max (rows,length)
        else:
            # main/lower diagonal k<=0: its n-th element sits at (n-k, n)
            rows = max (rows,-k+length)
            cols = max (cols,length)
    if not (cols and rows):
        raise ValueError('diag_matrix(): rows=%s and cols=%s must be positive'%(rows,cols))
    flat_list = []
    for r in range(rows):
        for c in range(cols):
            k = c-r
            # position of (r, c) along its diagonal
            if k>0: n = r
            else: n = c
            try:
                flat_list.append(diagonals[k][n])
            except (KeyError,IndexError):
                # no such diagonal, or diagonal shorter than the matrix
                flat_list.append (0)
    return ex(_ginac.matrix(rows,cols,flat_list))

__all__.append ('diag_matrix')

def zero_matrix(r,c,init=0):
    """Return an r x c matrix whose elements all equal init.

    r and c must be positive ints; otherwise ValueError is raised.
    """
    if not (type (r) == type (c) == types.IntType and r>0 and c>0):
        raise ValueError('zero_matrix() first two arguments must be positive ints')
    filler = ex(init).ex
    return ex(_ginac.matrix0(r,c,filler))

__all__.append ('zero_matrix')

def flat_matrix(r,c,l):
    """Return an r x c matrix built from the flat sequence l.

    r and c must be positive ints and len(l) must equal r*c; otherwise
    ValueError is raised.  l is handed to _ginac.matrix unchanged.
    """
    shape_ok = type (r) == type (c) == types.IntType and r>0 and c>0
    if not (shape_ok and len(l)==r*c):
        raise ValueError('flat_matrix() first two arguments must be positive ints such that their products equal to the length of the 3rd argument')
    return ex(_ginac.matrix(r,c,l))

__all__.append ('flat_matrix')

class wildcard(ex,arith_ops,relat_ops):
    """Wildcard for subs(), match(), has(), and find() methods.

      wildcard(label)

    where label is positive int.
    """
    def __init__ (self, label):
        self.ex = _ginac.wildcard(label)
    def get_label(self):
        """Get label."""
        return self.ex.wildcard_get_label()

ex._derived_classes['wildcard'] = wildcard

class pseries(ex,arith_ops,degree_mths,functional_mths):
    """Extended truncated power series (positive and negative
    integer powers).

      use series(var==point,order) method instead of pseries.
    """
    def __init__(self):
        """Use series(var==point,order) method instead of pseries."""
        raise RuntimeError('use series(var==point,order) method to construct power series')
    def is_zero (self):
        """Check whether the series is zero."""
        return self.ex.pseries_is_zero ()
    def get_var(self):
        """Get the expansion variable."""
        return ex(self.ex.pseries_get_var())
    def get_point(self):
        """Get the expansion point."""
        return ex(self.ex.pseries_get_point())
    def get_order(self):
        """Get the expansion order (degree in the expansion variable)."""
        # NOTE(review): body was missing in this copy; restored as the
        # counterpart of get_lorder() (degree instead of ldegree).
        return self.ex.degree(self.ex.pseries_get_var())
    def get_lorder(self):
        """Get the expansion lowest order."""
        return self.ex.ldegree(self.ex.pseries_get_var())
    def get_compatible(self,other,deg=None):
        """Return other as a power series compatible with self.

        If other is already a pseries, it must be compatible with self
        (ValueError otherwise) and is returned unchanged.  Otherwise other
        is expanded around self's variable and point up to order deg
        (default: self.get_order()).
        """
        if isinstance(other,pseries):
            if self.is_compatible_to(other):
                return other
            raise ValueError('not compatible power series')
        if deg is None:
            deg = self.get_order()
        return other.series(self.get_var()==self.get_point(),deg)
    def is_compatible_to(self,other):
        """Check whether series is compatible to another series
        (whether expansion variable and point are the same)."""
        if isinstance (other,pseries):
            return self.ex.pseries_is_compatible_to(other.ex)
        return 0
    def is_terminating(self):
        """Returns true if there is no order term."""
        return self.ex.pseries_is_terminating()
    def convert_to_poly(self, no_order=0):
        """Convert the pseries object to an ordinary polynomial.

        no_order - flag for discarding higher order terms.
        """
        return ex(self.ex.pseries_convert_to_poly(no_order))
    def shift_exponents(self,deg):
        """Return a new pseries object with the powers shifted by deg."""
        return ex(self.ex.pseries_shift_exponents(deg))

    def __pow__(self,other,deg=None):
        """self**other.

        For an integer exponent the result is re-expanded as a series
        compatible with self.  deg (the third argument of pow()) is the
        truncation order; by default it is the larger of |order*exponent|
        and |lowest order*exponent|, plus 2.
        """
        # Accept plain Python numbers (e.g. s**2): they have no .ex
        # attribute, so convert first.
        other = ex(other)
        if isinstance(other,numeric) and other.is_integer():
            if deg is None:
                i = int(other)
                deg = abs(self.get_order()*i)
                ldeg = abs(self.get_lorder()*i)
                deg = max(deg,ldeg)+2
            return self.get_compatible(ex(_ginac.power(self.ex,other.ex)),deg)
        return ex(_ginac.power(self.ex,other.ex))

ex._derived_classes['pseries'] = pseries

class function(ex,arith_ops,relat_ops,functional_mths):
    """A registered function applied to arguments.

      function(serial, arg1, arg2, ..., hold=0)

    serial identifies the function (compare _ginac.function_find()); every
    argument is converted with ex().  The hold keyword (default 0) is
    passed through to _ginac.function; other keywords are ignored.
    """
    def __init__(self, ser, *args, **kws):
        hold = kws.get('hold', 0)
        operands = [ex(a) for a in args]
        self.ex = _ginac.function(ser, operands, hold)
    def get_name(self):
        """Get function name."""
        return self.ex.function_get_name()
    def get_serial(self):
        """Get function serial number."""
        return self.ex.function_get_serial()
    def is_function(self,name=None,noargs=None):
        """Check if function has name and number of arguments.

        Without arguments this is always true.  Otherwise the serial
        number is compared with _ginac.function_find(name, noargs).
        """
        if name is None and noargs is None:
            return 1
        serial = self.get_serial()
        return serial == _ginac.function_find(name, noargs)

ex._derived_classes['function'] = function

class fderivative(function):
    """Derivative of a registered function with respect to some arguments.

      fderivative(serial, params, arg1, arg2, ..., hold=0)

    params is passed to _ginac.fderivative unchanged; every argument is
    converted with ex().  get_params() reads the archived 'param' field.
    """
    def __init__(self,serial,params,*args,**kws):
        hold = kws.get('hold', 0)
        operands = [ex(a) for a in args]
        self.ex = _ginac.fderivative(serial, params, operands, hold)
    def get_params(self):
        """Return the differentiation parameters as a list of ints."""
        archive = self.ex.extract_archived(0)
        return [int(p) for p in archive['param']]

ex._derived_classes['fderivative'] = fderivative

def Diff(params,func):
    """Return fderivative(func's serial, params, *func's arguments).

    func must be a function instance; otherwise TypeError is raised.
    """
    if isinstance(func, function):
        operands = func.to_list()
        return fderivative(func.get_serial(), params, *operands)
    raise TypeError('Diff() expected function second argument.')

class DiffF:
    """Deferred derivative of a Python function that builds a `function`.

    DiffF(params, func)(*args) returns Diff(params, func(*args)).
    params must be a list and func a plain Python function.
    """
    def __init__(self,params,func):
        assert type(params) is types.ListType
        assert type(func) is types.FunctionType
        self.params = params
        self.func = func
    def __call__(self,*args):
        return Diff(self.params,self.func(*args))
    def __repr__(self):
        return 'DiffF(%s, %s)'%(repr(self.params), self.func.func_name)
    def __str__(self):
        return 'D%s(%s)'%(self.params, self.func.func_name)

class D:
    """Derivative operator.

      D[i]            -> D([i])
      D[i,j,...]      -> D([i,j,...])
      D[...](f)       -> DiffF(params, f)      (f a Python function)
      D(params, func) -> Diff(params, func)
    """
    def __init__(self, params = None):
        # None instead of a shared mutable [] default.
        if params is None:
            params = []
        self.params = params
    def __call__(self,*args):
        if len(args)==1:
            return DiffF(self.params,args[0])
        return Diff(*args)
    def __getitem__(self,item):
        # D[i,j] passes a tuple; D[i] a single index.
        if type(item) is types.TupleType:
            item = list(item)
        else:
            item = [item]
        return self.__class__(item)
    def __repr__(self):
        return 'D(%s)'%(repr(self.params))
    def __str__(self):
        return 'D%s'%(self.params)
# The class is shadowed by its parameterless instance; __getitem__ still
# reaches the class through self.__class__.
D = D()

__all__ += ['D']

class Function:
    """Base class for user-defined functions.

    Derive a class from Function, optionally define the methods listed at
    the end of this class, and pass the class to build_function().  Each
    derived class may be instantiated only once (RuntimeError otherwise).
    """
    # Do not redefine the following methods in derived classes.
    defined = 0     # set to 1 on a derived class once it has been instantiated
    cache = []      # shared list read by the wrappers that build_function() generates
    # Code-object flags marking *args / **kws parameters.
    _CO_VARARGS = 0x04
    _CO_VARKEYWORDS = 0x08
    def __init__(self):
        cls = self.__class__
        # Look only at the class's own flag.  Inheriting Function.defined
        # (or a parent's) must not block a fresh subclass.
        if vars(cls).get('defined', 0):
            raise RuntimeError('class %s is already used'%(cls.__name__))
        # Mark the *class*: an instance attribute would be invisible to the
        # next instantiation, so the check above could never fire.
        cls.defined = 1
        self.fixednparams = 1
    def __call__(self,*args,**kws):
        if self.fixednparams and len(args)!=self.nparams:
            raise TypeError('function %s takes exactly %s arguments but got %s'%(self.get_name(),self.nparams,len(args)))
        hold = kws.get('hold', 0)
        return self.func(*args,**{'hold':hold})
    def _check_arity(self, name, extra):
        """Return how many expression arguments method `name` takes.

        extra is the number of other parameters (self, plus diff_param or
        rel/order/options).  Return None if the method is not defined or
        accepts *args/**kws, i.e. does not fix its argument count.
        """
        if not hasattr(self, name):
            return None
        method = getattr(self, name)
        code = getattr(method, 'func_code', None) or method.__code__
        # Local variables also appear in co_varnames, so comparing its
        # length with co_argcount would misreport ordinary methods.
        if code.co_flags & (self._CO_VARARGS | self._CO_VARKEYWORDS):
            return None
        return code.co_argcount - extra
    def _check_eval(self):
        return self._check_arity('eval', 1)
    def _check_evalf(self):
        return self._check_arity('evalf', 1)
    def _check_derivative(self):
        return self._check_arity('derivative', 2)
    def _check_series(self):
        return self._check_arity('series', 4)
    def get_nparams(self):
        """Return (and cache as self.nparams) the number of arguments.

        The count is taken from the optional eval/evalf/derivative/series
        methods, which must agree (TypeError otherwise).  If none fixes it,
        0 is used and the argument count is not checked on calls.
        """
        if hasattr(self,'nparams'):
            return self.nparams
        p = None
        for f in ['eval','evalf','derivative','series']:
            n = getattr(self,'_check_'+f)()
            if n is None: continue
            if p is None:
                p = n
            elif p!=n:
                raise TypeError('Function.get_nparams(): %s must define %s arguments but got %s'%(f,p,n))
        if p is None:
            p = 0
            self.fixednparams = 0
        self.nparams = p
        return p

    # Users may redefine only the following methods
    def get_name(self):
        return self.__class__.__name__
    def get_TeX_name(self):
        return '\\mbox{%s}'%(self.get_name())
    # and optionally define methods:
    #   eval()
    #   evalf()
    #   derivative(,diff_param)
    #   series(,rel,order,options)

def build_function(cls):
    """Register a Function sub-class with GiNaC and return its callable.

    cls is either a direct sub-class of Function or a string; a string
    first creates an empty Function sub-class of that name.  The class is
    instantiated once.  Its eval/evalf/derivative/series methods, if
    present, are wrapped in generated functions that look the instance up
    in Function.cache by index.  The returned callable is also stored as
    the instance's `func` attribute.

    Raises TypeError if cls is not a direct sub-class of Function.

    NOTE(review): several lines of this function are missing from this
    copy: the closing lines of the exec'd template strings, the rest of
    the _ginac.build_function(...) call, the code that defines `fun`, and
    presumably the statement that appends `func` to Function.cache.  It
    does not compile as is -- restore from upstream.

    NOTE(review): the string form of cls is interpolated into source code
    passed to `exec`; never pass an untrusted string here.
    """
    if type(cls) is types.StringType:
        exec 'class %s(Function): pass\ncls = %s'%(cls,cls)
        return build_function(cls)
    if Function not in cls.__bases__:
        raise TypeError,'build_function() argument must be sub-class of Function'
    func = cls()
    name = func.get_name()
    texname = func.get_TeX_name()
    nparams = func.get_nparams()
    # Wrappers stay None when the corresponding method is not defined.
    eval_f = None
    evalf_f = None
    derivative_f = None
    series_f = None
    # Index of this instance in Function.cache, baked into the generated
    # wrappers.  NOTE(review): `id` shadows the builtin id().
    id = len(Function.cache)
    if hasattr(func,'eval'):
        exec '''def eval_f(*x):
    import ginac
    return ginac.Function.cache[%d].eval(*map(ginac.ex,x))
    if hasattr(func,'evalf'):
        exec '''def evalf_f(*x):
    import ginac
    return ginac.Function.cache[%d].evalf(*map(ginac.ex,x))
    if hasattr(func,'derivative'):
        exec '''def derivative_f(*x):
    import ginac
    return ginac.Function.cache[%d].derivative(*(map(ginac.ex,x[:-1])+[x[-1]]))
    if hasattr(func,'series'):
        exec '''def series_f(*x):
    import ginac
    return ginac.Function.cache[%d].series(*(map(ginac.ex,x[:-2])+[x[-2],x[-1]]))
    serial = _ginac.build_function(name,texname,nparams,
    exec '''def %s(*x,**k):
    try: hold = k["hold"]
    except: hold = 0
    return function(%d,*x,**{"hold":hold})
    func.func = fun
    return fun

__all__ += ['Function']

#TODO: be more explicit, put it in separate module, say, ginac.functions
# Define module-level Python wrappers for GiNaC's built-in one- and
# two-argument functions, looked up by name via _ginac.function_find.
# NOTE(review): the rest of the one-argument name list (and its closing
# `]:`) is missing from this copy of the file.
if _ginac_ok:
    for n in ['abs','csgn','sin','cos','tan','exp','log',
        # abs is not wrapped under its own name (presumably to avoid
        # shadowing Python's builtin abs()); see symbolic_abs below.
        if n in ['abs']: continue
        s = _ginac.function_find(n,1)
        exec 'def %s(x): return function(%d,x)\n'%(n,s)
    for n in ['eta','atan2','beta','binomial']:
        s = _ginac.function_find(n,2)
        exec 'def %s(x,y): return function(%d,x,y)\n'%(n,s)
    # GiNaC's abs, exported as symbolic_abs.
    s = _ginac.function_find('abs',1)
    exec 'def %s(x): return function(%d,x)\n'%('symbolic_abs',s)

    # drop the loop helpers from the module namespace
    del n,s

def Digits(prec=None):
    """Get/set number of decimal digits.

    Digits()          -> current number of digits
    Digits('default') -> presumably resets to the default; returns the
                         previous value
    Digits(n)         -> presumably sets n digits (n an int); returns the
                         previous value
    Any other argument raises TypeError.

    NOTE(review): the statements that actually set the precision (the
    bodies of the 'default' and int branches, and the `else:` that should
    guard the raise) are missing from this copy; it does not compile as
    is -- restore from upstream.
    """
    prev = _ginac.get_Digits()
    if prec is None:
        return prev
    if type(prec) is types.StringType and prec=='default':
    elif type(prec) is types.IntType:
        raise TypeError,'Digits() expected None|int|"default" but got '+`type(prec)`
    return prev

# Mathematical constants and cached small numbers; only available when the
# _ginac extension module was imported.
if _ginac_ok:
    I = numeric('I')
    E = constant('E',lambda : exp(1),'e')
    Pi = ex(_ginac.Pi)
    Euler = ex(_ginac.Euler)
    Catalan = ex(_ginac.Catalan)
    # Raw _ginac expressions for 1/2, 1, -1 and 0 used internally
    # (e.g. sqrt() uses _ex1_2).
    _ex1_2 = numeric(1,2).ex
    _ex1 = numeric(1).ex
    _ex_1 = numeric(-1).ex
    _ex0 = numeric(0).ex

# Alternative name for build_function.
newfunction = build_function

# NOTE(review): the rest of the following list (and its closing `]`) is
# missing from this copy of the file.
__all__ += ['I','Pi','Euler','Catalan','E',
# Export every registered ex sub-class except function and fderivative.
__all__ += [k for k in ex._derived_classes.keys()
            if k not in ['function','fderivative']]

# Functions
def sqrfree(a,l=None):
    """Compute square-free factorization of multivariate polynomial a in
    Q[X].  l is lst of variables X to factor in, may be left empty for
    autodetection.  l may also be a Python list of variables or a single
    expression; l=None (the default) means an empty list.

    When the factorization is a sum, every term is factorized separately.
    Terms that stay unchanged are collected and factorized together
    (e.g. expand((a+b)**2 + e*(c+d)**2)).
    """
    if l is None:
        l = []
    if not isinstance(l,lst):
        if type(l) is types.ListType:
            l = lst(*l)
        else:
            l = lst(l)
    res = ex(_ginac.sqrfree(ex(a).ex,l.ex))
    if not isinstance(res,add):
        return res
    # handle cases like expand((a+b)**2 + e*(c+d)**2).
    unchanged = 0       # sum of terms sqrfree() left as they were
    factored = 0        # sum of terms sqrfree() rewrote
    for t in res.to_list():
        nt = sqrfree(t,l)
        if nt.is_equal(t):
            unchanged += t
        else:
            factored += nt
    if factored and unchanged:
        return sqrfree(unchanged,l) + factored
    return unchanged + factored

def decomp_rational(a,x):
    """Decompose rational function a(x)=N(x)/D(x) into P(x)+n(x)/D(x)
    with degree(n, x) < degree(D, x).

    x must be a symbol instance; otherwise TypeError is raised.
    """
    if not isinstance (x,symbol):
        raise TypeError('decomp_rational() function 2nd argument must be symbol')
    return ex(_ginac.decomp_rational(ex(a).ex,x.ex))

def expand(e):
    """Return ex(e).expand(): e is converted with ex() and then expanded."""
    expr = ex(e)
    return expr.expand()

if _ginac_ok:
    import atexit
    # NOTE(review): the atexit.register(...) call that presumably sat here
    # is missing from this copy; as written this import has no effect.
    del atexit

# _ginac_ok is an import-time flag only; remove it from the module namespace.
del _ginac_ok

# FreeBSD-CVSweb (web-page footer captured with this listing; not part of the module)