#Creates the .hwv file for a single specific tableau
#Usually this script is not called manually; instead, "tableau2polynomial_directory.py" is used to process a whole directory of files containing tableaux.


#If you really want to use it, here is an example usage:
#python3 tableau2polynomial.py "\"(7,4,1)\"" 3 4 "\"(1,1,1,1,2,3,3,2,2,2,3,3)\""

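#parameters:
#(1) The partition, i.e., the shape of the Young diagram, passed as a doubly quoted tuple (see the example above)
#(2) The outer plethysm parameter
#(3) The inner plethysm parameter
#(4) The entries of the tableau, read row by row, also passed as a doubly quoted tuple.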


#This python3 program uses the file "blueprint.c" and calls "gcc".
#It creates a file "hwv.hwv" as an output, provided there is enough memory.
#The memory usage is quite high. If there is not enough memory, then a file
#"MEMplO[outerdegree]I[innerdegree]X[partition]X[tableau].txt"
#is created. Example: MEMplO4I4X14_2X1_1_1_1_2_2_2_2_3_3_3_3_4_4_4_4.txt
#When run on a cluster computer, the existence of these files points precisely to the cases where there was not enough memory.
#A common reason is that too many other memory-intensive processes are running at the same time. Close those processes and try again.




import itertools
import functools
import math
import os
import subprocess
import sys


# Refuse to run under Python 2.
# Note: Python 2 would already reject this file with a SyntaxError because of the
# print(..., end=..., flush=...) calls further down, so in practice this branch
# is not reached there.
if sys.version_info < (3,):
  print("Python must be at least version 3.")
  exit()
  

# Folder in which the "not enough memory" marker files are created.
outofmemfolder = "."

if len(sys.argv) < 5:
  sys.exit("usage: python3 " + sys.argv[0] + " PARTITION OUTERDEGREE INNERDEGREE TABLEAUENTRIES"
           " (see the comments at the top of the script for an example)")

# The partition and the tableau entries arrive doubly quoted on the command line
# (e.g. "\"(7,4,1)\""), hence the double eval.
# NOTE(review): eval() executes arbitrary expressions; the arguments must come from
# a trusted caller (normally tableau2polynomial_directory.py).
p = eval(eval(sys.argv[1]))
OUTERDEGREE = eval(sys.argv[2])
INNERDEGREE = eval(sys.argv[3])
rowwisecontent = eval(eval(sys.argv[4]))

# Content of the tableau of shape p whose row r is filled with r:
# row index r repeated p[r] times, e.g. p=(2,1) -> [0, 0, 1].
target = [row for row in range(len(p)) for _ in range(p[row])]

# Per-variable hash values: the variable with (row) index r hashes to varhashes[r].
# NOTE(review): entries 10 and 12 are both 386210057, so hashvariable is not
# injective for partitions with 13 or more rows (findhashfunctions checks this),
# and only 15 rows are supported at all. These values presumably have to agree
# with the hashing performed by the C code generated from blueprint.c -- confirm
# before changing any of them.
varhashes = [
  321017603, 409032069, 179149675, 385763043, 386210053,
  178229077, 514005093, 137972357, 108903411, 477157075,
  386210057, 108903449, 386210057, 409032079, 321017611,
]

def hashvariable(variable):
  """Return the hash value assigned to the variable with index ``variable``."""
  hashed = varhashes[variable]
  return hashed

def hashmonomial(monomial):
  """Hash a monomial (an iterable of variable indices) modulo 536870909.

  Every variable contributes hashvariable(var) ** monomialhashingfunction,
  computed by repeated multiplication mod 536870909; the contributions are
  summed mod 536870909. The sum does not depend on the order of the variables.
  monomialhashingfunction is the module global chosen by findhashfunctions().
  """
  modulus = 536870909
  total = 0
  for var in monomial:
    base = hashvariable(var)
    power = 1
    for _ in range(monomialhashingfunction):
      power = power * base % modulus
    total = (total + power) % modulus
  return total



def hashmonomialmonomial(monomialmonomial):
  """Primary hash of a monomial-in-monomials (a collection of monomials), mod 536870909.

  The hashmonomial value of each inner monomial is squared
  monomialmonomialhashingfunction times (mod 536870909), i.e. raised to the
  power 2 ** monomialmonomialhashingfunction, and the results are summed
  mod 536870909.
  """
  modulus = 536870909
  total = 0
  for mon in monomialmonomial:
    value = hashmonomial(mon)
    for _ in range(monomialmonomialhashingfunction):
      value = value * value % modulus
    total = (total + value) % modulus
  return total

def hashmonomialmonomial2(monomialmonomial):
  """Secondary hash of a monomial-in-monomials, mod 536870909.

  Same construction as hashmonomialmonomial but with one extra squaring: each
  inner monomial's hashmonomial value is raised to the power
  2 ** (monomialmonomialhashingfunction + 1). It is used to tell apart
  monomials-in-monomials whose primary hashes collide.
  """
  modulus = 536870909
  squarings = monomialmonomialhashingfunction + 1
  total = 0
  for mon in monomialmonomial:
    value = hashmonomial(mon)
    for _ in range(squarings):
      value = value * value % modulus
    total = (total + value) % modulus
  return total

# Primary hash values (hashmonomialmonomial) shared by more than one
# monomial-in-monomials of the target weight; recomputed by findhashfunctions().
hashcollisions = set()

def findhashfunctions():
  """Choose hash exponents that make hashing of the relevant monomials injective.

  Sets these module globals:
    monomialhashingfunction          -- smallest exponent >= 1 for which
                                        hashmonomial is injective on all monomials
                                        of degree INNERDEGREE in len(p) variables;
    monomialmonomialsofcorrectweight -- set of all OUTERDEGREE-multisets of such
                                        monomials whose combined content equals
                                        ``target``;
    monomialmonomialhashingfunction  -- smallest exponent >= 8 for which the
                                        primary hash, together with the secondary
                                        hash (hashmonomialmonomial2) for colliding
                                        primary values, identifies every element
                                        of monomialmonomialsofcorrectweight;
    hashcollisions                   -- the colliding primary hash values;
    maxhash                          -- the largest hash value in use.

  Raises ValueError if hashvariable is not injective on the rows of p, or if
  no monomial-in-monomials has the content ``target``.
  Note: the exponent searches do not terminate if no suitable exponent exists.
  """
  global monomialhashingfunction, monomialmonomialhashingfunction
  global monomialmonomialsofcorrectweight, hashcollisions, maxhash

  # The variable hashes themselves must already be pairwise distinct.
  if len({hashvariable(v) for v in range(len(p))}) != len(p):
    raise ValueError("hashvariable is not injective on the {} rows of the partition".format(len(p)))

  # Smallest exponent for which hashmonomial separates all monomials.
  monomials = list(itertools.combinations_with_replacement(range(len(p)), INNERDEGREE))
  for monomialhashingfunction in itertools.count(1):
    if len({hashmonomial(m) for m in monomials}) == len(monomials):
      break

  # All OUTERDEGREE-multisets of monomials whose combined content is `target`.
  monomialmonomialsofcorrectweight = set()
  for t in itertools.combinations_with_replacement(monomials, OUTERDEGREE):
    if sorted(itertools.chain.from_iterable(t)) == target:
      monomialmonomialsofcorrectweight.add(t)
  if not monomialmonomialsofcorrectweight:
    raise ValueError("no product of {} monomials of degree {} has the content of the partition {}".format(OUTERDEGREE, INNERDEGREE, p))

  for monomialmonomialhashingfunction in itertools.count(8):
    primaryhashes = {t: hashmonomialmonomial(t) for t in monomialmonomialsofcorrectweight}
    hashes = set()
    hashcollisions = set()
    for h in primaryhashes.values():
      if h in hashes:
        hashcollisions.add(h)
      hashes.add(h)
    for t, h in primaryhashes.items():
      if h in hashcollisions:
        hashes.add(hashmonomialmonomial2(t))
    # Each colliding primary value stays in `hashes` (as a collision marker)
    # without identifying an element, hence the subtraction.
    if len(monomialmonomialsofcorrectweight) == len(hashes) - len(hashcollisions):
      maxhash = max(hashes)
      break
    print(len(hashes), "<", len(monomialmonomialsofcorrectweight), " coll=" , len(hashcollisions), monomialmonomialhashingfunction)


def univpolyinvariable(exponent,variable):
  """Return a C expression for ``variable`` raised to ``exponent``, reduced mod ``prime``.

  The power is spelled out as nested multiplications, each reduced modulo
  ``prime``, and added to ``zero64``. ``zero64`` and ``prime`` are C identifiers
  that the generated code must provide (presumably defined in blueprint.c).
  For exponent <= 1 the expression contains just ``variable``.
  """
  term = "(" + variable + ")"
  for _ in range(exponent - 1):
    term = "((%s*%s)%%prime)" % (term, variable)
  return "(zero64+%s)%%prime" % term

def univpolyinvariableREPEATEDSQUAREPREPARE(exponent,variable):
  """Return C statements that square ``variable`` ``exponent`` times modulo ``prime``.

  Declares ``uint64_t ___<variable>`` initialised with ``variable`` and then
  squares it in place ``exponent`` times, reducing modulo ``prime`` after each
  squaring. The result is read via univpolyinvariableREPEATEDSQUARE.
  """
  temporary = "___" + variable
  squaring = temporary + " = " + temporary + "*" + temporary + ";" + temporary + "%=prime;"
  return "uint64_t " + temporary + "=" + variable + ";" + squaring * exponent

def univpolyinvariableREPEATEDSQUARE(exponents,variable):
  """Return the name of the C temporary declared by univpolyinvariableREPEATEDSQUAREPREPARE.

  ``exponents`` is unused: the squaring itself is done by the PREPARE statements.
  """
  return "".join(("___", variable))



def _blockstructure():
  """Return the text that replaces a ##BLOCKSTRUCTURE line of blueprint.c.

  One line per row of p: the row's entries taken from rowwisecontent (which
  lists the tableau entries row by row), each written as "<entry>-1", padded
  with "0-1" up to the length p[0] of the first row.
  """
  text = "    "
  start = 0
  for rowlength in p:
    entries = rowwisecontent[start:start+rowlength] + (0,)*(p[0]-rowlength)
    text += ','.join(str(entry)+"-1" for entry in entries) + ",\n    "
    start += rowlength
  return text

def _colpermdeclarations(collengths):
  """Return the C declarations replacing ##INTCOLPERMFORWARD0BACKWARD1IS0 (one pair per column longer than 1)."""
  text = ""
  for col in range(p[0]):
    if collengths[col]>1:
      text += "  int col"+str(col)+"permforward0backward1 = 0;\n"
      text += "  int col"+str(col)+"perm = 0;\n"
  return text

def _collisionmarkers():
  """Return the C statements replacing ##COLLISIONS: hwv[h]=COLLISIONMARKER for every colliding hash h."""
  return "".join("  hwv["+str(collision)+"]=COLLISIONMARKER;\n" for collision in hashcollisions)

def _longcomputation(collengths):
  """Return the C code replacing ##LONGCOMPUTATION.

  For every column longer than 1 it emits one level of a nested if/else that
  steps the column permutation forward or backward through the swaplength<n>
  tables (one SWAPINTABLEAU per step); all the levels' closing braces follow at
  the end.
  """
  text = ""
  for col in range(p[0]):
    if collengths[col]>1:
      c = str(col)
      n = str(collengths[col])
      text += "    if (col"+c+"perm<"+str(math.factorial(collengths[col]))+"-1 && col"+c+"permforward0backward1==0) {\n"
      text += "      col"+c+"perm++;SWAPINTABLEAU(tableau,swaplength"+n+"row0[col"+c+"perm],"+c+",swaplength"+n+"row1[col"+c+"perm],"+c+");\n"
      text += "      swappedcol="+c+";swappedrow0=swaplength"+n+"row0[col"+c+"perm];swappedrow1=swaplength"+n+"row1[col"+c+"perm];\n"
      text += "    } else if (col"+c+"perm>0 && col"+c+"permforward0backward1==1) {\n"
      text += "      swappedcol="+c+";swappedrow0=swaplength"+n+"row0[col"+c+"perm];swappedrow1=swaplength"+n+"row1[col"+c+"perm];\n"
      text += "      SWAPINTABLEAU(tableau,swaplength"+n+"row0[col"+c+"perm],"+c+",swaplength"+n+"row1[col"+c+"perm],"+c+");col"+c+"perm--;\n"
      text += "    } else {\n"
      text += "      col"+c+"permforward0backward1=!col"+c+"permforward0backward1;\n"
  for col in range(p[0]):
    if collengths[col]>1:
      text += "    }"
  return text

def generatetempcfile():
  """Write "temp.c" by filling in the "##..." placeholders of "blueprint.c".

  Every placeholder value depends only on module globals (p, OUTERDEGREE,
  INNERDEGREE, rowwisecontent, outofmemfolder, hashcollisions and the hash
  exponents chosen by findhashfunctions()), so the values are computed once and
  then applied, in a fixed order, to each blueprint line. Lines that contain
  ##BLOCKSTRUCTURE, ##INTCOLPERMFORWARD0BACKWARD1IS0, ##COLLISIONS or
  ##LONGCOMPUTATION are replaced in full by generated code.

  If the first column of the Young diagram is longer than 8, a message is printed
  and the script exits (the message says a new Gray code would be needed). This
  check runs before any file is opened, so no partial temp.c is left behind.
  """
  monomialmonomialhashingfunction2 = monomialmonomialhashingfunction+1
  # collengths[k] is the length of column k of the Young diagram of p.
  collengths = [sum(1 for rowlength in p if rowlength >= k) for k in range(1, p[0]+1)]
  if collengths[0]>8:
    print("A new Gray code for a larger symmetric group is needed: swaplength?row0,swaplength?row1")
    sys.exit()
  notenoughmemfilename = ("\""+outofmemfolder+"/MEMplO"+str(OUTERDEGREE)+"I"+str(INNERDEGREE)
    +"X"+str(p).replace(" ","").replace("(","").replace(")","").replace(",","_")
    +"X"+str(rowwisecontent).replace(" ","").replace("(","").replace(")","").replace(",","_")+".txt\");")
  # (placeholder, replacement) pairs, applied in this order to every line.
  replacements = [
    ("##NUMBEROFROWS", str(len(p))),
    ("##NUMBEROFCOLS", str(p[0])),
    ("##OUTERDEGREE", str(OUTERDEGREE)),
    ("##INNERDEGREE", str(INNERDEGREE)),
    ("##TWISTEDVAR", univpolyinvariable(monomialhashingfunction,"var")),
    ("##TWISTEDVA2", univpolyinvariable(monomialhashingfunction,"var")),
    ("##NEWENTRY0", univpolyinvariable(monomialhashingfunction,"newentry0")),
    ("##NEWENTRY1", univpolyinvariable(monomialhashingfunction,"newentry1")),
    ("##NEWENTR20", univpolyinvariable(monomialhashingfunction,"newentry0")),
    ("##NEWENTR21", univpolyinvariable(monomialhashingfunction,"newentry1")),
    ("##PREPARETWISTEDBLOCKHASH", univpolyinvariableREPEATEDSQUAREPREPARE(monomialmonomialhashingfunction,"blockhash")),
    ("##TWISTEDBLOCKHASH", univpolyinvariableREPEATEDSQUARE(monomialmonomialhashingfunction,"blockhash")),
    ("##PREPARETWISTEDBLOCKHAS2", univpolyinvariableREPEATEDSQUAREPREPARE(monomialmonomialhashingfunction2,"blockhash")),
    ("##TWISTEDBLOCKHAS2", univpolyinvariableREPEATEDSQUARE(monomialmonomialhashingfunction2,"blockhash")),
    ("##PREPAREBLOCKHASH0", univpolyinvariableREPEATEDSQUAREPREPARE(monomialmonomialhashingfunction,"blockhash0")),
    ("##BLOCKHASH0", univpolyinvariableREPEATEDSQUARE(monomialmonomialhashingfunction,"blockhash0")),
    ("##PREPAREBLOCKHASH1", univpolyinvariableREPEATEDSQUAREPREPARE(monomialmonomialhashingfunction,"blockhash1")),
    ("##BLOCKHASH1", univpolyinvariableREPEATEDSQUARE(monomialmonomialhashingfunction,"blockhash1")),
    ("##PREPAREBLOCKHAS20", univpolyinvariableREPEATEDSQUAREPREPARE(monomialmonomialhashingfunction2,"blockhash0")),
    ("##BLOCKHAS20", univpolyinvariableREPEATEDSQUARE(monomialmonomialhashingfunction2,"blockhash0")),
    ("##PREPAREBLOCKHAS21", univpolyinvariableREPEATEDSQUAREPREPARE(monomialmonomialhashingfunction2,"blockhash1")),
    ("##BLOCKHAS21", univpolyinvariableREPEATEDSQUARE(monomialmonomialhashingfunction2,"blockhash1")),
    ("##NOTENOUGHMEMFILENAME", notenoughmemfilename),
    ("##ROWLENGTHS", str(p).replace("(","").replace(")","")),
    ("##COLLENGTHS", str(collengths).replace("[","").replace("]","")),
    ("##PRODUCTOFCOLUMNLENGTHFACTORIALS", '*'.join(str(math.factorial(length))+"ULL" for length in collengths)),
    ("##TABLEAU", ",\n    ".join([','.join([str(k)]*p[k] + ['X']*(p[0]-p[k])) for k in range(len(p))])),
  ]
  with open("blueprint.c", "r") as blueprintfile, open("temp.c", "w") as tempcfile:
    for blueprintline in blueprintfile:
      for placeholder, replacement in replacements:
        blueprintline = blueprintline.replace(placeholder, replacement)
      # Marker lines are replaced in their entirety by generated code.
      if "##BLOCKSTRUCTURE" in blueprintline:
        blueprintline = _blockstructure()
      if "##INTCOLPERMFORWARD0BACKWARD1IS0" in blueprintline:
        blueprintline = _colpermdeclarations(collengths)
      if "##COLLISIONS" in blueprintline:
        blueprintline = _collisionmarkers()
      if "##LONGCOMPUTATION" in blueprintline:
        blueprintline = _longcomputation(collengths)
      tempcfile.write(blueprintline)

def multinomialcoefficientofcontent(monomial):
  """Return len(monomial)! divided by (count of i in monomial)! for i = 0..max(monomial).

  For a monomial given as a sequence of non-negative variable indices this is
  the number of distinct orderings of its entries. An empty monomial raises
  ValueError (from max()).
  """
  result = math.factorial(len(monomial))
  for entry in range(max(monomial) + 1):
    result //= math.factorial(monomial.count(entry))
  return result

def _monomialmonomialstring(monomialmonomial):
  """Render a monomial-in-monomials as a product of hwv.hwv variables.

  Each distinct inner monomial (in order of first occurrence) becomes a variable
  "vx<i1>x<i2>..." with its indices shifted to start at 1, followed by
  "^<multiplicity>" when it occurs more than once; factors are joined by "*".
  Example: ((0, 0), (0, 0), (1, 2)) -> "vx1x1^2*vx2x3".
  """
  factors = []
  for monomial in dict.fromkeys(monomialmonomial):
    factor = "vx" + "x".join(str(entry + 1) for entry in monomial)
    multiplicity = monomialmonomial.count(monomial)
    if multiplicity > 1:
      factor += "^" + str(multiplicity)
    factors.append(factor)
  return "*".join(factors)

def printhwvhwvfile(monomialmonomialsofcorrectweight):
  """Write the computed highest weight vector to "hwv.hwv" as a polynomial.

  For each monomial-in-monomials ``t``, the coefficient is read from the global
  dict ``hwv`` (filled by main() from the C program's output) under
  hashmonomialmonomial(t). The value 536870909 marks a hash collision; in that
  case the coefficient is read under hashmonomialmonomial2(t) instead. Entries
  missing from ``hwv`` are skipped.

  Each coefficient is multiplied by len(m)! / multinomialcoefficientofcontent(m)
  for every inner monomial m, and all coefficients are divided by their gcd. The
  terms are then written in sorted order as "c1*vx... + c2*vx... - c3*vx..."
  (no trailing newline).

  Zero coefficients are omitted. They would otherwise print as a bare
  "*monomial", and if all of them were zero the gcd would be 0.
  """
  monomialmonomial2coeffdict = dict()
  gcd = 0
  for t in monomialmonomialsofcorrectweight:
    hashm = hashmonomialmonomial(t)
    if hashm not in hwv:
      continue
    value = hwv[hashm]
    if value == 536870909:
      # Collision marker: the real coefficient is stored under the secondary hash.
      hashm2 = hashmonomialmonomial2(t)
      if hashm2 not in hwv:
        continue
      value = hwv[hashm2]
    for monomial in t:
      # Exact division: the multinomial coefficient divides len(monomial)!.
      value = value * math.factorial(len(monomial)) // multinomialcoefficientofcontent(monomial)
    if value == 0:
      continue
    monomialmonomial2coeffdict[t] = value
    gcd = math.gcd(gcd, value)
  for t in monomialmonomial2coeffdict:
    monomialmonomial2coeffdict[t] //= gcd

  with open("hwv.hwv", "w") as hwvhwvfile:
    firstterm = True
    for monomialmonomial in sorted(monomialmonomial2coeffdict):
      coeff = monomialmonomial2coeffdict[monomialmonomial]
      if coeff < 0:
        sign = "-" if firstterm else " - "
      else:
        sign = "" if firstterm else " + "
      hwvhwvfile.write(sign + str(abs(coeff)) + "*" + _monomialmonomialstring(monomialmonomial))
      firstterm = False


def main():
  """Compute the highest weight vector for the tableau given on the command line.

  Steps:
    1. Choose the hash exponents (findhashfunctions).
    2. Generate temp.c from blueprint.c (generatetempcfile) and compile it with gcc.
    3. Run the resulting program.
    4. Execute the lines it wrote to temp.py, which fill the global dict ``hwv``.
    5. Write hwv.hwv (printhwvhwvfile).
  temp.c, temp.out and temp.py are deleted along the way.

  If temp.py is still empty after the run, an empty marker file
  MEMplO<outer>I<inner>X<partition>X<entries>.txt is created in
  ``outofmemfolder`` and any existing hwv.hwv is removed. According to the
  comments in this file, the usual cause is that the C program ran out of memory.

  Exits with an error message if gcc does not compile temp.c successfully.
  """
  findhashfunctions()
  generatetempcfile()
  print("gcc ... ", end='', flush=True)
  gccstatus = os.system("gcc -O3 temp.c -o temp.out")
  os.remove("temp.c")
  if gccstatus != 0:
    # Otherwise the failure would only surface later as a FileNotFoundError for temp.out.
    sys.exit("gcc failed to compile temp.c (status {})".format(gccstatus))
  print(" ... gcc done.", flush=True)
  print("run computation ... ", end='', flush=True)
  # Create temp.py empty up front: if the C program runs out of memory it may
  # never write to it, and an empty temp.py is how that case is detected below.
  with open("temp.py", "w"):
    pass
  os.system("./temp.out")
  os.remove("temp.out")
  print("... computation done.", flush=True)
  global hwv
  hwv = dict()
  with open('temp.py', 'r') as temppyfile:
    output = temppyfile.readlines()
  os.remove("temp.py")
  if not output:
    outofmemfilename = (outofmemfolder+"/MEMplO"+str(OUTERDEGREE)+"I"+str(INNERDEGREE)
      +"X"+str(p).replace(" ","").replace("(","").replace(")","").replace(",","_")
      +"X"+str(rowwisecontent).replace(" ","").replace("(","").replace(")","").replace(",","_")+".txt")
    with open(outofmemfilename, 'w'):
      pass
    if os.path.exists("hwv.hwv"):
      os.remove("hwv.hwv")
  else:
    # NOTE(review): temp.py is written by the program compiled above, and every
    # line without a "#" is exec'd verbatim (presumably assignments such as
    # hwv[<hash>]=<coefficient>; confirm against blueprint.c). Only safe with a
    # trusted blueprint.c.
    for line in output:
      if "#" not in line:
        exec(line)
    printhwvhwvfile(monomialmonomialsofcorrectweight)
  


# Run the whole pipeline as soon as this file is executed. There is no
# `if __name__ == "__main__"` guard, and the command-line arguments are parsed at
# module level above, so importing this file also runs everything.
main()

