#
# Wordle in z3 and plain Python
#
# Note: I wondered how z3's String constraints would do on Wordle.
#       Well, it's not fast, so I wrote some plain Python versions as well.
# 
# - wordle1(): using z3's String constraints. Slow
# - wordle1b(): using z3's String constraints with push()/pop(). Faster
# - wordle2(): using plain Python functions (but no regexes). Fast
# - wordle3(): using regexes. Fast.
#
#
# For the last 'tacit' problem we can see that the regex approach (wordle3) is
# the fastest:
#  - wordle1(): 1.89s
#  - wordle1b(): 0.36s
#  - wordle2(): 0.0014s
#  - wordle3(): 0.00066s
#
#
# All programs have the same interface:
#
# *  wordle3(words,correct_positions,correct_chars,not_in_word)
#
#   - correct_positions: These characters are in correct positions
#                        Example: ".ac.t":
#                            - "a" is in 2nd position
#                            - "c" is in 3rd position
#                            - "t" is in 5th position
#   - correct_chars: These characters are in the word but not in the
#                    given position (if != ""). An entry may hold several
#                    characters, e.g. "no".
#                    Example: ["c","","a","",""]
#                    - "c" is in the target word, but not in 1st position
#                    - "a" is in the target word, but not in 3rd position
#
#   - not_in_word: These characters are not in the target word.
#                   Example: "slnedyh"
#                   - none of these characters are in the target word.
#      
# Here's an example, the last run for the 'tacit' target word
# given the earlier guesses: slant, cadet, yacht:
#
#    candidates = wordle3(candidates,".ac.t",["c","","a","",""],"slnedyh")
#
# This is used in the tests below.
#
#
# Note: All programs require a word list with just 5-letter words,
#       which I call wordle_words.txt
#
# This z3 model was written by Hakan Kjellerstrand (hakank@gmail.com)
# See also my z3 page: http://hakank.org/z3/
#
import re, time
from z3_utils_hakank import *

#
# Frequencies (reversed) for each position in the word
# See http://hakank.org/picat/wordle.pi
# for details about this.
#
# One string of 26 letters per word position. sort_words() scores a letter
# at position i by its index in freq[i], so letters nearer the end of a
# string score higher.
# NOTE(review): with that scoring the last letter of several strings (e.g.
# 'x' for position 1, 'j' for position 2, 'q' for position 4) is the best
# scoring letter, which looks surprising for a frequency order -- confirm
# against wordle.pi.
freq = ["zyjkquinovhewlrmdgfaptbcsx", # first position in words
        "zqfkgxvbsdymcwptnhulieroaj", # second ...
        "qjhzkxfwyvcbpmgdstlnrueoia",
        "xyzbwhfvpkmdguotcrilasnejq",
        "uzxbiwfcsgmpoakdnhlrtyejqv"]

# Word list file, read line by line at the bottom of this script; the path
# is relative to the current working directory.
word_list = "wordle_words.txt"

#
# Sort the words according to the frequency scores.
# Here are the best scored words
# ['crane', 219], ['saint', 219], ['brine', 217], ['coast', 217], ['crone', 217], ['boast', 216],
# ['briny', 216], ['cause', 216], ['crony', 216], ['paint', 215], ['poise', 215], ['slant', 215],
# ['beast', 214], ['borne', 214], ['clone', 214], ['corny', 214], ...
#
def sort_words(words, n=5, freqs=None):
    """
    Sort the words according to the frequencies of each position
    in the words. High is better.

    Parameters:
      words: iterable of words; word[i] must exist for every i < n.
      n:     number of leading positions to score (default 5).
      freqs: one string per position; a letter's index in freqs[i] is its
             score at position i. Defaults to the module-level freq table.

    A letter at position i scores index + i/2 (a letter missing from
    freqs[i] scores nothing), and words with n distinct letters get a bonus
    of 100. The total is truncated to int.

    Returns a list of (word, score) pairs, highest score first. Words with
    equal scores keep the order of their first occurrence; duplicate words
    are collapsed into one entry.
    """
    if freqs is None:
        freqs = freq
    scores = {}
    for word in words:
        # Score according to frequencies
        score = 0
        for i in range(n):
            for s, c in enumerate(freqs[i]):
                if word[i] == c:
                    score += s + i/2  # factor + i/2 seems good
        # We prefer words with distinct letters
        if len(set(word)) == n:
            score += 100
        scores[word] = int(score)

    # sorted() is stable, so ties keep insertion order.
    return sorted(scores.items(), key=lambda item: -item[1])


#
# Using z3's String constraint
# 
def wordle1(words, correct_pos, correct_chars, not_in_word, num_sols=0):
    """
    Wordle solver using z3's String constraints.

    It uses a scored version of the word list, as generated by
    sort_words(), and climbs towards the maximum-scored word: after each
    solution the next one must be a different word with a strictly higher
    score.

    Parameters:
      words:         list of (word, score) pairs, e.g. from sort_words().
      correct_pos:   pattern such as ".ac.t"; every non-'.' character must
                     be at that position of the target word.
      correct_chars: correct_chars[i] holds characters that are in the
                     target word but not at position i ("" for none).
      not_in_word:   characters that do not occur in the target word.
      num_sols:      stop after this many solutions; 0 means no limit.

    Returns the solutions found as z3 model values (not Python str), the
    highest-scored first.

    Note: The candidates returned are not all possible candidates, just
          the ones found on the way to the maximum.
    """
    # Solver choice, timed on the 'tacit' benchmark:
    #  - Optimize() with s.maximize(v): 27.2s
    #  - Solver(): 5.35s with the default string solver, about 2.0s with
    #    "z3str3" or "auto"
    #  - SolverFor("QF_S"): 1.91s with "z3str3", 1.94s with "auto",
    #    5.95s with "seq"
    #  - set_param("parallel.enable", True) was slower
    s = SolverFor("QF_S")
    s.set("smt.string_solver","auto")

    v = Int('v')  # score of the chosen word
    target = String('target')
    print("num_words:", len(words))

    # Note: bounding v (0 <= v <= max score) made QF_S slower
    # (1.91s -> 2.97s), and InRe(target, Plus(Range("a","z"))) was slower too.
    s.add(Length(target) == 5)

    # The target is one of the words, and v is that word's score.
    s.add(Or([And(target == word, v == score) for word, score in words]))

    # Correct positions
    for i, c in enumerate(correct_pos):
        if c != '.':
            s.add(target.at(i) == StringVal(c))

    # Correct characters:
    #  - every character of correct_chars[i] must be in the word
    #  - but not at position i
    for i, chars in enumerate(correct_chars):
        for c in chars:
            s.add(Contains(target,c) == True)
            s.add(target.at(i) != StringVal(c))

    # Not in target
    for c in not_in_word:
        s.add(Contains(target,c) == False)

    num_solutions = 0
    candidates = []
    while s.check() == sat:
        num_solutions += 1
        mod = s.model()
        target_val = mod[target]
        candidates.append(target_val)
        v_val = mod[v]
        # Stop once num_sols solutions have been collected (0 = no limit).
        if num_sols > 0 and num_solutions >= num_sols:
            break
        # Next: a different word with a strictly higher score.
        s.add(target != target_val, v > v_val)

    # Solutions were found in increasing score order; return the best first.
    candidates.reverse()
    return candidates



#
# Using z3's String constraint, another approach
# 
def wordle1b(words, correct_pos, correct_chars, not_in_word, num_sols=0):
    """
    Wordle solver using z3's String constraints, one word at a time.

    This is a variant of wordle1(): the constraints are built once, and
    then every word in the word list is tested with

       s.push()
       s.add(target == word)
       if s.check() == sat:
          # add to candidates
       s.pop()

    It is significantly faster than wordle1(): 0.38s vs 1.9s (on the tacit
    problem). Here we don't use precalculated scores: `words` is a plain
    list of words, and the accepted words are sorted by sort_words() on
    return (best first).

    num_sols is accepted for interface compatibility with wordle1() and is
    not used.
    """
    # The configuration used for the timing above. "parallel.enable", the
    # QF_S / QF_SLIA logics and other string solvers were also tried.
    s = SolverFor("ALL")
    s.set("smt.string_solver","auto")

    target = String('target')

    # Note: Length(target) == 5 is deliberately not added; it made check()
    # very slow in some instances.

    # Correct positions
    for i,c in enumerate(correct_pos):
        if c != '.':
            s.add(target.at(i) == c)

    # Correct characters:
    #  - every character of correct_chars[i] must be in the word
    #  - but not at position i
    # Giving z3 the "not at position i" constraint seems to hang when the
    # character also occurs in correct_pos. Those (position, character)
    # pairs are collected and checked in plain Python in the word loop
    # below, so the constraint is still enforced.
    python_checked = []
    for i,chars in enumerate(correct_chars):
        for c in chars:
            # Char in target
            s.add(Contains(target,c) == True)
            if c in correct_pos:
                python_checked.append((i, c))
            else:
                s.add(target.at(i) != c)

    # Not in target
    for c in not_in_word:
        s.add(Contains(target,StringVal(c)) == False)

    # Loop through all words and keep those accepted as target words.
    # This is the main difference to wordle1().
    candidates = []
    for word in words:
        # A position past the end of the word cannot hold c.
        if any(i < len(word) and word[i] == c for i, c in python_checked):
            continue
        s.push()
        s.add(target == word)
        if s.check() == sat:
            candidates.append(word)
        s.pop()

    return [word for word,score in sort_words(candidates)]



#
# Plain Python approach (no regexes)
#
# There is no need to presort the words; the sorting is done
# only for the candidates
#
def wordle2(words, correct_pos, correct_chars, not_in_word):
    """
    Wordle solver using plain Python (but no regex).

    A word survives when
      - every non-'.' character of correct_pos sits at the same position,
      - each character of correct_chars[i] occurs in the word, though not
        at position i,
      - no character of not_in_word occurs in the word.

    The input needs no pre-sorting: only the survivors are ranked, via
    sort_words(), best first.
    """
    green_positions = [pos for pos, ch in enumerate(correct_pos) if ch != "."]
    green_letters = [ch for ch in correct_pos if ch != "."]

    def survives(word):
        # Correct positions
        if [word[pos] for pos in green_positions] != green_letters:
            return False
        # Correct characters, but in the wrong position
        for pos, chars in enumerate(correct_chars):
            if not all(ch in word for ch in chars):
                return False
            if any(word[pos] == ch for ch in chars):
                return False
        # Not in target word
        return not any(ch in word for ch in not_in_word)

    survivors = [w for w in words if survives(w)]
    return [word for word, _ in sort_words(survivors)]

#
# Plain Python program, regex approach.
#
# There is no need to presort the words; the sorting is done
# only for the candidates
#
def wordle3(words, correct_pos, correct_chars, not_in_word):
    """
    Wordle solver using regexes.

    Same interface and semantics as wordle2():
      - correct_pos: pattern such as ".ac.t", matched at the start of each
        word ('.' matches any letter).
      - correct_chars[i]: characters that must all occur in the word, but
        none of them at position i ("" for no constraint). An entry may
        hold several characters, e.g. "no".
      - not_in_word: characters that must not occur in the word ("" is
        allowed and excludes nothing).

    Returns the matching words ordered by sort_words() score, best first.
    """
    correct_pos_re = re.compile(correct_pos)

    # One (contains_re, wrong_pos_re) pair per non-empty correct_chars entry:
    #  - contains_re has one lookahead per character, so each character must
    #    occur somewhere in the word, in any order (not as a substring);
    #  - wrong_pos_re matches when position i holds one of the characters.
    yellow_res = []
    for i, chars in enumerate(correct_chars):
        if chars:
            escaped = [re.escape(c) for c in chars]
            contains_re = re.compile("".join("(?=.*%s)" % e for e in escaped))
            wrong_pos_re = re.compile(".{%d}(?:%s)" % (i, "|".join(escaped)))
            yellow_res.append((contains_re, wrong_pos_re))

    # An empty character class "[]" is not a valid regex, so no pattern is
    # built when there is nothing to exclude.
    if not_in_word:
        not_in_word_re = re.compile("|".join(re.escape(c) for c in not_in_word))
    else:
        not_in_word_re = None

    candidates = []
    for word in words:
        # Correct positions
        if not correct_pos_re.match(word):
            continue

        # Correct characters:
        #  - correct_chars[i] must be in word
        #  - but not in position i
        if any(not contains_re.match(word) or wrong_pos_re.match(word)
               for contains_re, wrong_pos_re in yellow_res):
            continue

        # Not in target word
        if not_in_word_re is not None and not_in_word_re.search(word):
            continue

        # It's a candidate
        candidates.append(word)

    # Return the words in score order
    return [word for word,score in sort_words(candidates)]



def test_wordle1(words):
    """
    Test wordle1() on the 'tacit' benchmark (the slowest solver).

    wordle1() needs (word, score) pairs, so the word list is scored with
    sort_words() first.

    The full game with target word 'tacit', first guess 'slant':
      wordle1(scored, "....t", ["","","a","",""], "sln", 0)      -> cadet
      wordle1(scored, ".a..t", ["c","","a","",""], "slned", 0)   -> yacht
      wordle1(scored, ".ac.t", ["c","","a","",""], "slnedyh", 0) -> tacit
    Only the last step, the benchmark, is run here.

    Other examples:
      wordle1(scored, ".l...", ["","","","",""], "sant")
      Game slant, tried, truce, trope, trove:
        wordle1(scored, ".....", ["","","","","t"], "slan")           -> tried
        wordle1(scored, "tr...", ["","","","e","t"], "slaniduc")      -> trope
        wordle1(scored, "tr...", ["","","","e","t"], "slaniducp")     -> trove
    """
    scored_words = sort_words(words)
    found = wordle1(scored_words, ".ac.t", ["c","","a","",""], "slnedyh", 0)
    print("candidates:", found)


def test_wordle1b(words):
    """
    Test wordle1b() on the 'tacit' benchmark.

    This is faster than wordle1(). Unlike wordle1(), wordle1b() takes the
    plain word list (no sort_words() scores needed) and sorts its own
    candidates.

    The full game with target word 'tacit', first guess 'slant':
      wordle1b(words, "....t", ["","","a","",""], "sln", 0)      -> cadet
      wordle1b(words, ".a..t", ["c","","a","",""], "slned", 0)   -> yacht
      wordle1b(words, ".ac.t", ["c","","a","",""], "slnedyh", 0) -> tacit
    Only the last step, the benchmark, is run here.
    """
    num_sols = 0
    # This is the benchmark
    candidates = wordle1b(words,".ac.t",["c","","a","",""],"slnedyh",num_sols)
    # -> tacit

    # Another example: all words without s, a, n or t
    # candidates = wordle1b(words,".....",["","","","",""],"sant")
    print("candidates:",candidates)



def test_wordle2(words):
    """
    Test wordle2(), the plain-Python solver (no regexes).

    wordle2() filters the raw word list directly, so no scoring or sorting
    is needed up front.

    Game with target 'tacit', first guess 'slant':
      ".....", ["","","","",""], ""          -> all words
      "....t", ["","","a","",""], "sln"      -> cadet
      ".a..t", ["c","","a","",""], "slned"   -> yacht
      ".ac.t", ["c","","a","",""], "slnedyh" -> tacit (success 4/6)

    Game with target 'thorn', first guess 'slant' (total run time 0.18s,
    compared with 0.054s for the Picat program wordle.pi):
      ".....", ["","","","n","t"], "sla"      -> tenor (not 'thorn' as in
                                                 the Picat program)
      "t....", ["","","n","no","tr"], "slae"  -> thorn (success 3/6)

    Only the final 'tacit' step is run here.
    """
    print("candidates:", wordle2(words, ".ac.t", ["c","","a","",""], "slnedyh"))


def test_wordle3(words):
    """
    Test wordle3(), the regex-based solver.

    Like wordle2(), it filters the raw word list, so no pre-sorting or
    scoring is needed.

    Game with target 'tacit', first guess 'slant':
      ".....", ["","","","",""], ""          -> all words
      "....t", ["","","a","",""], "sln"      -> cadet
      ".a..t", ["c","","a","",""], "slned"   -> yacht
      ".ac.t", ["c","","a","",""], "slnedyh" -> tacit (success 4/6)
    Only the final step is run here.
    """
    final_step = wordle3(words, ".ac.t", ["c","","a","",""], "slnedyh")
    print("candidates:", final_step)

#
# Results on the last 'tacit' run.
#
# candidates: ["tacit"]
# test_wordle1(): 1.8919973373413086
# 
# candidates: ['tacit']
# test_wordle1b(): 0.35560011863708496
# 
# candidates: ['tacit']
# test_wordle2(): 0.0014348030090332031
# 
# candidates: ['tacit']
# test_wordle3(): 0.0006635189056396484
# 
#
t0 = time.time()
# One word per line; the with-statement closes the file after reading.
with open(word_list) as word_file:
    words = [word.rstrip() for word in word_file]
t1 = time.time()
print("Reading words:", t1-t0)

# test_wordle1() is the slow z3 version (about 1.9s on this benchmark);
# uncomment to include it.
# test_wordle1(words)
t2 = time.time()
# print("test_wordle1():", t2-t1)
# print()

test_wordle1b(words)
t3 = time.time()
print("test_wordle1b():", t3-t2)
print()


test_wordle2(words)
t4 = time.time()
print("test_wordle2():", t4-t3)
print()

test_wordle3(words)
t5 = time.time()
print("test_wordle3():", t5-t4)
