from polybori.nf import *
import polybori.aes as aesmod

from polybori.PyPolyBoRi import *
from polybori.ll import eliminate, ll_encode
from time import time
from copy import copy
from itertools import chain
from inspect import getargspec
from statistics import used_vars, used_vars_set
from heuristics import dense_system,gauss_on_linear
from itertools import chain
from polybori.interpolate import lex_groebner_basis_for_polynomial_via_variety
def owns_one_constant(I):
    """Return True if some polynomial in I is the constant one.

    Scanning stops at the first polynomial whose isOne() is true; an empty
    I yields False.
    """
    return any(poly.isOne() for poly in I)


#strategy brainstorming

#count the number of unbound variables to decide whether ll is a good choice

#-direct computation
#-ll first
#-dp_asc first (don't need walk, as result is usually quite nice)
#-conjunction and factoring first
#-inversion

#LA at few variables and no ll
#LA at "dense" systems


#@TODO: make this work with *args, **kwds of the wrapped function

def want_interpolation_gb(G):
    """Decide whether G should be handled by the interpolation-based lex GB.

    All of the following must hold:
    - the current ordering is lp;
    - G has exactly one generator;
    - that generator's leading monomial has degree > 1;
    - its set has at most 1000 nodes, i.e. ``p.set().nNodes() <= 1000``
      (presumably the size of its decision-diagram representation).
    """
    if get_order_code() != OrderCode.lp or len(G) != 1:
        return False
    poly = Polynomial(G[0])
    if poly.lmDeg() <= 1 or poly.set().nNodes() > 1000:
        return False
    return True

def ll_is_good(I):
    """Heuristic test whether linear-lexicographic elimination (ll) pays off.

    Collects the distinct variables that occur as degree-1 lexicographic
    leading terms of the generators. Returns True only if their number is at
    least 80% of len(I) and more than 90% of the number of used variables.
    """
    linear_lex_leads = set()
    for poly in I:
        lex_lm = poly.lexLead()
        if lex_lm.deg() == 1:
            linear_lex_leads.add(next(iter(lex_lm.variables())).index())
    n_leads = len(linear_lex_leads)
    if n_leads < 0.8 * len(I):
        return False
    return n_leads > 0.9 * len(used_vars_set(I))
    
def ll_heuristic(d):
    """Enable "llfirst" when ll_is_good() says so and the caller did not decide.

    Works on and returns a shallow copy of the option dict ``d``; the flag is
    only set if neither "llfirst" nor "llfirstonthefly" is already present.
    """
    opts = copy(d)
    polys = opts["I"]
    undecided = "llfirstonthefly" not in opts and "llfirst" not in opts
    if undecided and ll_is_good(polys):
        opts["llfirst"] = True
    return opts



def change_order_heuristic(d):
    """Suggest computing a Groebner basis in another ordering first.

    Returns a shallow copy of the option dict ``d``. The copy gets
    "other_ordering_first" set only if all of these hold:
    - the caller has not set "other_ordering_first";
    - the current ordering is lp or dlex;
    - for lp only: more than half of the generators have a leading term of
      degree > 1.

    The value stored is a copy of ``d`` extended with "switch_to" (lp
    switches to dlex, dlex to dp_asc).
    """
    result = copy(d)
    polys = result["I"]
    next_ordering = {OrderCode.lp: OrderCode.dlex,
                     OrderCode.dlex: OrderCode.dp_asc}
    if "other_ordering_first" in result:
        return result
    # TODO: after ll the situation might look quite different, so this
    # heuristic may be applied at the wrong place.
    code = get_order_code()
    if code not in next_ordering:
        return result
    if code == OrderCode.lp:
        limit = len(polys) / 2
        nonlinear = 0
        for poly in polys:
            if poly.lead().deg() > 1:
                nonlinear += 1
                if nonlinear > limit:
                    break
        if nonlinear <= limit:
            return result
    switched = copy(d)
    switched["switch_to"] = next_ordering[code]
    result["other_ordering_first"] = switched
    return result


def interpolation_gb_heuristic(d):
    """Switch to the interpolation-based lex Groebner basis when suitable.

    Works on and returns a shallow copy of the option dict ``d``. If the
    caller did not request "other_ordering_first" (falsy or absent) and
    want_interpolation_gb(d["I"]) holds, the copy gets "interpolation_gb" set
    to True and "other_ordering_first" set to False.
    """
    d = copy(d)
    I = d["I"]
    # The option key is "other_ordering_first" (the key change_order_heuristic
    # writes and the other_ordering_first wrapper reads). The former lookup of
    # "other_ordering_opts" never matched, so an explicit request for another
    # ordering was silently overridden here.
    if not d.get("other_ordering_first", False) and want_interpolation_gb(I):
        d["interpolation_gb"] = True
        d["other_ordering_first"] = False
    return d
def linear_algebra_heuristic(d):
    """Heuristically enable the linear-algebra ("faugere") reduction.

    Works on a shallow copy of the option dict ``d`` and returns it. If the
    caller chose neither "faugere" nor "noro", and want_la() judges the
    system suitable, the copy gets "faugere" set to True. It also gets
    defaults for any of these keys not given explicitly:
    - "red_tail": False
    - "selection_size": 10000
    - "ll": True
    """
    d_orig=d  # NOTE(review): unused
    d=copy(d)
    I=d["I"]
    def want_la():
        # Decide based on how many variables the generators use.
        # NOTE(review): used_vars_set(I, bound=b) presumably may stop counting
        # once b variables are found -- only comparisons against the bound
        # are made below, so that would be sufficient. TODO confirm.
        n_used_vars=None
        bound=None
        if have_degree_order():
            # Degree orderings: fewer than 200 used variables suffices.
            new_bound=200
            n_used_vars=len(used_vars_set(I,bound=new_bound))
            if n_used_vars<new_bound:
                return True
            # Reaching here means n_used_vars >= 200.
            bound=new_bound
        if dense_system(I):
            # Dense systems: fewer than 100 used variables suffices.
            new_bound=100
            # Recount only if no larger bound was already tried. If the
            # 200-bound check failed, n_used_vars >= 200 >= bound and the
            # test below fails as well.
            if not (bound and new_bound<bound):
                n_used_vars=len(used_vars_set(I,bound=new_bound))
                bound=new_bound
            if n_used_vars<bound:
                return True
        return False
    # An explicit choice of "faugere" or "noro" by the caller wins.
    if not (("faugere" in d) or ("noro" in d)):
        if want_la():
            # Enable linear algebra and fill in defaults the caller left open.
            d["faugere"]=True
            if not "red_tail" in d:
                d["red_tail"]=False
            if not "selection_size" in d:
                d["selection_size"]=10000
            if not ("ll" in d):
                d["ll"]=True

    return d

def trivial_heuristic(d):
    """No-op heuristic: hand back the option dict ``d`` itself, unchanged."""
    unchanged = d
    return unchanged
class HeuristicalFunction(object):
    """Callable wrapper that lets a heuristic adjust a function's options.

    On each call:
    1. Positional arguments are matched to the wrapped function's parameter
       names and merged into the keyword arguments.
    2. Unless the call passes a falsy ``heuristic`` keyword, the merged dict
       is passed through ``heuristic_function`` (dict -> dict).
    3. The wrapped function is invoked with the resulting dict as keyword
       arguments.

    Attributes:
        argnames, varargs, varopts, defaults: argspec of the wrapped function.
        options: default option values. This is ``f.options`` if present,
            otherwise it is built from the keyword defaults of ``f`` (empty
            when ``f`` has none).
        heuristicFunction: the heuristic callable.
        f: the wrapped function.
    """

    def __init__(self, f, heuristic_function):
        try:
            # inspect.getargspec was removed in Python 3.11. Both variants
            # start with (args, varargs, keywords/varkw, defaults, ...).
            from inspect import getfullargspec as argspec_of
        except ImportError:
            from inspect import getargspec as argspec_of
        (self.argnames, self.varargs, self.varopts,
         self.defaults) = tuple(argspec_of(f))[:4]
        if hasattr(f, "options"):
            self.options = f.options
        else:
            # The argspec reports defaults=None when f has no keyword
            # defaults; previously len(None) raised a TypeError here.
            defaults = self.defaults or ()
            first_default = len(self.argnames) - len(defaults)
            self.options = dict(zip(self.argnames[first_default:], defaults))
        self.heuristicFunction = heuristic_function
        self.f = f
        self.__doc__ = f.__doc__

    def __call__(self, *args, **kwds):
        complete_dict = copy(kwds)
        # "heuristic" is deliberately left in the dict and forwarded to f,
        # so f must accept it as a keyword argument.
        heuristic = complete_dict.get("heuristic", True)
        for (name, value) in zip(self.argnames, args):
            complete_dict[name] = value
        if heuristic:
            complete_dict = self.heuristicFunction(complete_dict)
        return self.f(**complete_dict)
        
def with_heuristic(heuristic_function):
    """Return a decorator that wraps a function in a HeuristicalFunction.

    The wrapper passes the merged options through ``heuristic_function``
    before each call (unless called with heuristic=False). It carries the
    wrapped function's ``__name__``.
    """
    def decorate(func):
        heuristical = HeuristicalFunction(func, heuristic_function)
        heuristical.__name__ = func.__name__
        return heuristical
    return decorate
def clean_polys(I):
    """Convert the entries of I to Polynomial, skipping those equal to zero."""
    zero = Polynomial(0)
    cleaned = []
    for candidate in I:
        if candidate != zero:
            cleaned.append(Polynomial(candidate))
    return cleaned
def clean_polys_pre(I):
    """Pre hook for "clean_arguments": return (clean_polys(I), None)."""
    cleaned = clean_polys(I)
    return cleaned, None
def gb_with_pre_post_option(option,pre=None,post=None,if_not_option=tuple(),default=False, pass_option_set=False):
    """Build a decorator that adds a keyword option to a GB function.

    The returned decorator wraps ``f(I, **kwds)``. When ``option`` is enabled
    for a call (taken from the call's keywords, falling back to ``default``),
    ``pre`` may transform the input before ``f`` runs and ``post`` may
    transform the result afterwards.

    :param option: name of the keyword option handled by this wrapper. It is
        removed from the keywords before they reach ``f``.
    :param pre: optional callable ``pre(I) -> (I, state)``, or
        ``pre(I, option_value)`` when ``pass_option_set`` is true.
    :param post: optional callable ``post(result, state) -> result``.
    :param if_not_option: names of options that force this option off when
        they are enabled, either explicitly in the call or by default in
        ``groebner_basis.options``.
    :param default: value used when the caller does not pass ``option``.
    :param pass_option_set: pass the option's value to ``pre`` as a second
        argument.
    :returns: a decorator. The wrapper it produces copies ``f``'s name and
        docstring and exposes an ``options`` dict of defaults that includes
        ``option``.
    """
    def make_wrapper(f):
        def wrapper(I, **kwds):
            prot = kwds.get("prot", False)
            # A conflicting option that is switched on disables this option.
            # groebner_basis is looked up at call time, i.e. the fully
            # decorated public function.
            conflicting = any(
                (o in kwds and kwds[o]) or
                (o not in kwds and groebner_basis.options[o])
                for o in if_not_option)
            # kwds is this call's private dict, so popping is safe; the
            # option itself must not reach f.
            requested = kwds.pop(option, default)
            option_set = False if conflicting else requested
            state = None
            if option_set and pre:
                if prot:
                    print("preprocessing for option: %s" % (option,))
                if pass_option_set:
                    (I, state) = pre(I, option_set)
                else:
                    (I, state) = pre(I)
            res = f(I, **kwds)
            if option_set and post:
                if prot:
                    print("postprocessing for option: %s" % (option,))
                res = post(res, state)
            return res
        wrapper.__name__ = f.__name__
        wrapper.__doc__ = f.__doc__
        if hasattr(f, "options"):
            wrapper.options = copy(f.options)
        else:
            try:
                # inspect.getargspec was removed in Python 3.11.
                from inspect import getfullargspec as argspec_of
            except ImportError:
                from inspect import getargspec as argspec_of
            spec = argspec_of(f)
            argnames = spec[0]
            # defaults is None when f has no keyword defaults.
            defaults = spec[3] or ()
            first_default = len(argnames) - len(defaults)
            wrapper.options = dict(zip(argnames[first_default:], defaults))

        wrapper.options[option] = default
        return wrapper
    return make_wrapper
def redsb_post(I, state):
    """Post hook for "redsb": return I.minimalizeAndTailReduce().

    An empty list result is passed through as []; ``state`` is ignored.
    """
    return [] if I == [] else I.minimalizeAndTailReduce()
def minsb_post(I, state):
    """Post hook for "minsb": return I.minimalize().

    An empty list result is passed through as []; ``state`` is ignored.
    """
    return [] if I == [] else I.minimalize()
def invert_all(I):
    """Return a list with mapEveryXToXPlusOne() applied to each polynomial.

    Per the method name this presumably substitutes x -> x+1 for every
    variable.
    """
    inverted = []
    for poly in I:
        inverted.append(poly.mapEveryXToXPlusOne())
    return inverted
def invert_all_pre(I):
    """Pre hook for "invert": return (invert_all(I), None)."""
    inverted = invert_all(I)
    return inverted, None
def invert_all_post(I, state):
    """Post hook for "invert": apply invert_all to the result; state unused."""
    result = invert_all(I)
    return result
    
def llfirst_pre(I):
    """Pre hook for "llfirst": run eliminate(I, on_the_fly=False).

    Returns (remaining_polynomials, eliminated); the eliminated part is the
    state handed to the post hook.
    """
    eliminated, _llnf, remaining = eliminate(I, on_the_fly=False)
    return remaining, eliminated

def ll_constants_pre(I):
    """Pre hook for "ll_constants": split off generators "variable + constant".

    A generator is split off if all of these hold:
    - its lexicographic leading monomial has degree 1;
    - it has at most two terms;
    - removing its leading term leaves a constant;
    - its leading term was not already used by an earlier split-off
      generator.

    The remaining generators are reduced by the split-off ones via
    ll_red_nf, and results that become zero are dropped.

    Returns (reduced_generators, split_off_generators).
    """
    remaining = []
    split_off = []
    seen_leads = set()
    for poly in I:
        if poly.lexLmDeg() == 1:
            lead = poly.lead()
            if lead not in seen_leads and len(poly) <= 2:
                if (poly + lead).deg() <= 0:
                    split_off.append(poly)
                    seen_leads.add(lead)
                    continue
        remaining.append(poly)
    encoded = ll_encode(split_off)
    reduced = []
    for poly in remaining:
        normal_form = ll_red_nf(poly, encoded)
        if not normal_form.isZero():
            reduced.append(normal_form)
    return (reduced, split_off)

def other_ordering_pre(I,options):
    """Pre hook for "other_ordering_first": precompute a GB in another ordering.

    ``options`` is the option value, i.e. a dict of groebner_basis keyword
    arguments that contains "switch_to". The current ordering must be lp or
    dlex. Steps:
    1. Switch the global ring to ``options["switch_to"]``.
    2. Run groebner_basis there with the remaining options.
    3. If the intermediate basis contains any polynomial of degree > 1,
       append the original generators to it.
    4. Restore the previous ring.

    Returns (polynomials, None).
    """
    ocode = get_order_code()
    # In particular this does not work for block orderings, because of the
    # block sizes.
    assert (ocode == OrderCode.lp) or (ocode == OrderCode.dlex)
    old_ring = global_ring()
    try:
        change_ordering(options["switch_to"])
        kwds = dict((k, options[k]) for k in options
                    if k not in ("other_ordering_first", "switch_to", "I"))
        I_orig = I
        I = groebner_basis(I, **kwds)
        if any(p.deg() > 1 for p in I):
            I = list(chain(I, I_orig))
    finally:
        # Restore the caller's ring even if the nested computation raises,
        # so the global ordering is never left switched.
        old_ring.set()
    return (I, None)
def llfirstonthefly_pre(I):
    """Pre hook for "llfirstonthefly": run eliminate(I, on_the_fly=True).

    Returns (remaining_polynomials, eliminated); the eliminated part is the
    state handed to the post hook.
    """
    eliminated, _llnf, remaining = eliminate(I, on_the_fly=True)
    return remaining, eliminated
def llfirst_post(I, eliminated):
    """Post hook for "llfirst"/"llfirstonthefly": re-add eliminated polynomials.

    Behaviour depends on the result:
    - If the result contains the constant one, returns just [that one].
    - If ``eliminated`` is non-empty, recomputes a reduced Groebner basis of
      the result plus the eliminated polynomials, with the ll/ordering
      preprocessing options switched off.
    - Otherwise returns I unchanged.
    """
    for poly in I:
        if poly.isOne():
            return [poly]
    if len(eliminated) == 0:
        return I
    combined = list(chain(I, eliminated))
    # redsb is requested explicitly because the caller's setting is unknown.
    return groebner_basis(combined, llfirst=False, llfirstonthefly=False,
                          ll_constants=False, other_ordering_first=False,
                          redsb=True)


def ll_constants_post(I, eliminated):
    """Post hook for "ll_constants": append the split-off generators again.

    Behaviour depends on the result:
    - If the result contains the constant one, returns just [that one].
    - Otherwise, if ``eliminated`` is non-empty, returns a list of I followed
      by ``eliminated``.
    - Otherwise returns I unchanged.
    """
    for poly in I:
        if poly.isOne():
            return [poly]
    if len(eliminated) > 0:
        return list(chain(I, eliminated))
    return I

def result_to_list_post(I, state):
    """Post hook for "result_to_list": materialize the result as a list."""
    as_list = list(I)
    return as_list
def fix_deg_bound_post(I, state):
    """Post hook for "fix_deg_bound": unwrap a GroebnerStrategy result.

    If the computation returned a GroebnerStrategy, return its
    allGenerators(); any other result is returned unchanged.
    """
    if not isinstance(I, GroebnerStrategy):
        return I
    return I.allGenerators()


# Decorator stack of groebner_basis. The top-most decorator is the outermost
# wrapper: on a call, pre hooks and heuristics run top to bottom, and post
# hooks run bottom to top. Each gb_with_pre_post_option adds one keyword
# option (with its default) to groebner_basis.options. An option listed in
# another one's if_not_option switches that other option off when enabled.
# Drop zero generators and convert the rest to Polynomial.
@gb_with_pre_post_option("clean_arguments",pre=clean_polys_pre,default=True)
# May set llfirst=True (which in turn disables ll_constants below).
@with_heuristic(ll_heuristic)

# Turn the final result into a plain list.
@gb_with_pre_post_option("result_to_list",post=result_to_list_post,default=True)
# May set interpolation_gb=True and other_ordering_first=False.
@with_heuristic(interpolation_gb_heuristic)
# Apply mapEveryXToXPlusOne to the input and again to the result.
@gb_with_pre_post_option("invert",pre=invert_all_pre,post=invert_all_post,default=False)
# Split off "variable + constant" generators and append them afterwards.
@gb_with_pre_post_option("ll_constants",if_not_option=["llfirstonthefly","llfirst"],pre=ll_constants_pre,post=ll_constants_post,default=True)
# ll elimination before the computation; the post hook re-adds the eliminated
# polynomials and recomputes a reduced basis.
@gb_with_pre_post_option("llfirst",if_not_option=["llfirstonthefly"],pre=llfirst_pre,post=llfirst_post,default=False)
@gb_with_pre_post_option("llfirstonthefly",pre=llfirstonthefly_pre,post=llfirst_post,default=False)

# May set other_ordering_first to an option dict with "switch_to".
@with_heuristic(change_order_heuristic)
# Precompute a basis in the ordering options["switch_to"].
@gb_with_pre_post_option("other_ordering_first",if_not_option=["interpolation_gb"],pre=other_ordering_pre,default=False,pass_option_set=True)
# May enable faugere (plus red_tail/selection_size/ll defaults).
@with_heuristic(linear_algebra_heuristic)
# Unwrap a GroebnerStrategy result into its list of generators.
@gb_with_pre_post_option("fix_deg_bound",if_not_option=["interpolation_gb"], post=fix_deg_bound_post,default=True)
# Minimalize the result (switched off when redsb is on, which is the default).
@gb_with_pre_post_option("minsb",post=minsb_post,if_not_option=["redsb","deg_bound","interpolation_gb"],default=True)
# Minimalize and tail-reduce the result (off when a deg_bound is given).
@gb_with_pre_post_option("redsb",post=redsb_post,if_not_option=["deg_bound","interpolation_gb"],default=True)

def groebner_basis(I, faugere=False,
       preprocess_only=False, selection_size= 1000,
       full_prot= False, recursion= False,
       prot= False, step_factor= 1,
       deg_bound=False, lazy= True, ll= False,
       max_growth= 2.0, exchange= True,
       matrix_prefix= "matrix", red_tail= True,
       implementation="Python", aes= False,
       llfirst= False, noro= False, implications= False,
       draw_matrices= False, llfirstonthefly= False,
       linear_algebra_in_last_block=True, gauss_on_linear_first=True,heuristic=True,unique_ideal_generator=False, interpolation_gb=False):
    """Computes a Groebner basis of a given ideal I, w.r.t options.

    I is a sequence of polynomials (anything Polynomial() accepts); zero
    generators are dropped. Notable options:
    - interpolation_gb: requires exactly one generator and lp ordering
      (ValueError otherwise); uses
      lex_groebner_basis_for_polynomial_via_variety.
    - deg_bound: False means effectively unbounded (100000000).
    - unique_ideal_generator: replace the generators by the single
      polynomial prod(p+1) + 1.
    - gauss_on_linear_first: preprocess with gauss_on_linear.
    - implementation: "Python" selects symmGB_F2_python; any other value
      selects symmGB_F2_C.
    - aes: preprocess with polybori.aes.preprocess.
    - preprocess_only: print the preprocessed polynomials and exit the
      process.
    - full_prot: implies prot.
    The remaining options are passed on to the chosen implementation.
    """
    if interpolation_gb:
        if len(I)!=1 or get_order_code()!=OrderCode.lp:
            raise ValueError
        return lex_groebner_basis_for_polynomial_via_variety(I[0])
    # No bound requested: substitute a huge one.
    if deg_bound is False:
        deg_bound=100000000L
    zero=Polynomial(0)
    I=[Polynomial(p) for p in I if p!=zero]
    if unique_ideal_generator:
        # Over GF(2), prod(p+1) + 1 vanishes exactly where all generators do.
        prod=1
        for p in I:
            prod=(p+1)*prod
        I=[prod+Polynomial(1)]
    if gauss_on_linear_first:
        # gauss_on_linear comes from heuristics; presumably Gaussian
        # elimination on the linear generators -- TODO confirm.
        I=gauss_on_linear(I)
    # Presumably the sibling polybori.nf module (Python 2 implicit relative
    # import); its module-level matrix-drawing settings are set below.
    import nf
    if full_prot:
        prot=True
    
    nf.print_matrices=draw_matrices
    nf.matrix_prefix=matrix_prefix
    
    # The name is reused: the option string selects the actual function.
    if implementation=="Python":
        implementation=symmGB_F2_python
    else:
        implementation=symmGB_F2_C
    
    if aes:
        pt=time()
        I=aesmod.preprocess(I, prot=prot)
        pt2=time()
        if prot:
          print "preprocessing time", pt2-pt

    if preprocess_only:
      for p in I:
        print p
      # NOTE(review): if I is empty and no earlier statement bound p, this
      # del raises -- confirm whether that case can occur.
      del p
      del I
      import sys
      sys.exit(0)    
    I=implementation(I, optRedTail=red_tail,\
        max_growth=max_growth, step_factor=step_factor,
        implications=implications,prot=prot,
        full_prot=full_prot,deg_bound=deg_bound,
        selection_size=selection_size, optLazy=lazy, 
        optExchange=exchange, optAllowRecursion=recursion,
        use_faugere=faugere,
        use_noro=noro,ll=ll,
        optLinearAlgebraInLastBlock=linear_algebra_in_last_block)
    return I

# Append the list of option defaults to the public docstring.
groebner_basis.__doc__ = (
    groebner_basis.__doc__
    + "\nOptions are:\n"
    + "\n".join("%s  :  %r" % (name, value)
                for (name, value) in groebner_basis.options.items())
    + "\nTurn off heuristic by setting heuristic=False")
