#pragma warning( disable : 4503 4786 4800)

#include <iostream>
#include <algorithm>

extern "C"
  {
  #include "lua.h"
  #include "lauxlib.h"
  }

#include "trie.h"

using namespace std;

/*
 * Helper function: StringToLower
 * Returns a copy of str with every character passed through tolower().
 *
 * The copy is made up front so that indexing stays within size(); writing
 * into a string that was only reserve()d is undefined behaviour and leaves
 * the result empty. Each char is converted to unsigned char first because
 * tolower() is undefined for negative values other than EOF.
 */
static std::string StringToLower(const std::string & str)
{
  std::string result(str);

  for (std::string::size_type i = 0; i < result.length(); i++)
  {
    result[i] = static_cast<char>(tolower(static_cast<unsigned char>(result[i])));
  }

  return result;
}

// Build an empty trie. When caseSensitive is false, the public operations
// lower-case their word/prefix/pattern argument before touching the nodes.
LexicalTrie::LexicalTrie(bool caseSensitive)
{
  caseSensitive_ = caseSensitive;
  numWords_ = 0;
  root_ = new TrieNode();   // owned; released in ~LexicalTrie
}
// Release the node tree; each TrieNode's destructor deletes its own children.
LexicalTrie::~LexicalTrie()
{
  TrieNode * const tree = root_;
  delete tree;
}

// Insert one word (lower-cased first for case-insensitive tries). The word
// count only grows when the node layer reports the word as newly added.
void LexicalTrie::addWord(const std::string & word)
{
  const bool isNew = caseSensitive_ ? root_->addWord(word)
                                    : root_->addWord(StringToLower(word));
  if (isNew)
    ++numWords_;
}

void LexicalTrie::addWordsFromStream(std::istream & in)
{
  while ( true )
  {
    string word;
    getline(in, word);
    if ( !in.fail() )
    {
      if ( word != "" )
        this->addWord(word);
    }
    else
      break;
  };
}

// True if the trie has a node path spelling `prefix` (lower-cased first for
// case-insensitive tries). The empty prefix always succeeds.
bool LexicalTrie::containsPrefix(const std::string & prefix)
{
  return caseSensitive_ ? root_->containsPrefix(prefix)
                        : root_->containsPrefix(StringToLower(prefix));
}

// True if `word` is stored as a complete word (lower-cased first for
// case-insensitive tries).
bool LexicalTrie::containsWord(const std::string & word)
{
  if (!caseSensitive_)
    return root_->containsWord(StringToLower(word));

  return root_->containsWord(word);
}

// Add to `results` every stored word matching the wildcard pattern `exp`
// ('*' = zero or more characters, '?' = zero or one character). The pattern
// is lower-cased first for case-insensitive tries.
void LexicalTrie::matchRegExp(const std::string & exp, std::set<std::string> & results)
{
  const std::string pattern = caseSensitive_ ? exp : StringToLower(exp);
  root_->matchRegExp(results, pattern);
}

// Number of words currently counted as stored (maintained by addWord/removeWord).
int LexicalTrie::numWords()
{
  const int count = numWords_;
  return count;
}

/*
 * Remove one word (lower-cased first for case-insensitive tries) and
 * decrement the word count only if the word was actually stored.
 *
 * TrieNode::removeWord reports "the word was removed" through its bool&
 * out-parameter, and only ever sets it to true, so it must start as false.
 * Its return value means "this node is now empty and may be deleted". That
 * is irrelevant for the root, which is never deleted here, so it is ignored.
 */
void LexicalTrie::removeWord(const std::string & word)
{
  bool removed = false;

  if (caseSensitive_)
    root_->removeWord(word, removed);
  else
    root_->removeWord(StringToLower(word), removed);

  if (removed)
    numWords_--;
}

// Collect into `results` stored words reachable from `word` within
// maxEditDistance edits (the word is lower-cased first for case-insensitive
// tries). The node call relies on default values for editsUsed/soFar declared
// in trie.h (presumably 0 and "" -- confirm against the header).
void LexicalTrie::suggestCorrections(const std::string & word, int maxEditDistance,
                std::set<LexicalTrie::Correction> & results)
{
  const std::string target = caseSensitive_ ? word : StringToLower(word);
  root_->suggestCorrections(results, target, maxEditDistance);
}

// Write every stored word to `out`, one per line, in child-vector order.
void LexicalTrie::writeToStream(ostream & out)
{
  TrieNode * const top = root_;
  top->writeToStream(out);
}

void LexicalTrie::writeToSet(std::set<std::string> & out)
  {
  root_->writeToSet(out);
  }

/*
 * Helper that records a correction in theSet, keeping only the lowest edit
 * distance for a given correction.
 *
 * Which entries count as "the same correction" is decided by Correction's
 * ordering (declared in trie.h, not visible here). std::set elements cannot be
 * modified in place, so a smaller distance is applied by erasing the existing
 * entry and inserting a replacement. When the stored distance is already
 * minimal the set is left untouched.
 */
void LexicalTrie::AddToCorrections(std::set<Correction> & theSet,
                           const std::string & word, int editDistance)
{
  std::set<Correction>::iterator it = theSet.find(Correction(word, editDistance));

  if (it == theSet.end())
  {
    // Not found... so add it
    theSet.insert(Correction(word, editDistance));
  }
  else if (editDistance < it->editDistance_)
  {
    // Found with a larger distance... replace it with the smaller one
    theSet.erase(it);
    theSet.insert(Correction(word, editDistance));
  }
}


/*
==================================
   ____                        __ 
  / __/_ _____  ___  ___  ____/ /_
 _\ \/ // / _ \/ _ \/ _ \/ __/ __/
/___/\_,_/ .__/ .__/\___/_/  \__/ 
        /_/  /_/                  
  _______                   
 / ___/ /__ ____ ___ ___ ___
/ /__/ / _ `(_-<(_-</ -_|_-<
\___/_/\_,_/___/___/\__/___/
                            
==================================
*/

/*********************************
 * TrieNode                     *
 *********************************/

// Constructor: a fresh node has no children and does not terminate a word.
LexicalTrie::TrieNode::TrieNode()
  : isWord_(false)
{
}
// Destructor: each node deletes its child nodes, so destroying a node frees
// its whole subtree.
LexicalTrie::TrieNode::~TrieNode()
{
  for (std::vector<LetterTriePair>::size_type i = 0; i < letters_.size(); ++i)
    delete letters_[i].trie_;
}

bool LexicalTrie::TrieNode::addWord(const std::string & word)
{
  // recursive base case:
  if ( word == "" )
  {
    bool added = false;

    // was this not already a word?
    if (!isWord_)
      added = true;

    isWord_ = true;
    return added;
  }

  LetterTriePair * pair = findLetterPair(word[0]);
  if (pair)
  {
    // pair exists so update it...
    return pair->trie_->addWord(word.substr(1));
  }
  else
  {
    // pair doesn't exist, so create it
    TrieNode * newTrie = new TrieNode();

    // add the word recursively to the new trie
    newTrie->addWord(word.substr(1));

    letters_.push_back(LetterTriePair(word[0], newTrie));

    // keep the vector sorted
    sort(letters_.begin(), letters_.end());

    // in this case, the word was added because we didn't have
    // this branch to begin with...
    return true;
  }
}

bool LexicalTrie::TrieNode::containsPrefix(const std::string & prefix)
{
  // recursive base case
  if ( prefix == "" )
    return true;

  LetterTriePair * pair = this->findLetterPair(prefix[0]);
  if ( !pair )
    return false; // letter doesn't exist - prefix not in trie
  else
    return pair->trie_->containsPrefix( prefix.substr(1) );
}

bool LexicalTrie::TrieNode::containsWord(const std::string & word)
{
  if ( word == "" )
  {
    if (isWord_)
      return true;
    else
      return false;
  }

  LetterTriePair * pair = this->findLetterPair(word[0]);
  if ( !pair )
    return false;
  else
    return pair->trie_->containsWord(word.substr(1));
}

/*void LexicalTrie::TrieNode::Print(int depth)
{
  //PrintSpaces(depth);
  cout << (this->isWord == 0 ? "Word:  No" : "Word: Yes");
  cout << " - " << this->Letters.GetSize() << " children, ";
  cout << this->Letters.GetCapacity() << " capacity.";
  cout << endl;
  for ( size_t pos = 0; pos < Letters.GetSize(); pos++ )
  {
    PrintSpaces(depth+1);
    cout << Letters[pos].Letter << ": ";
    Letters[pos].Trie->Print(depth + 1);
  }
}*/

bool LexicalTrie::TrieNode::removeWord(const std::string & word, bool & removed)
{
  if ( word == "" )
  {
    // if this already was a word, mark that we removed it
    if (isWord_)
      removed = true;

    this->isWord_ = false;

    // are we empty?
    if ( letters_.size() == 0 )
      return true; // true: delete the node
    else
      return false; // false: don't delete - still has children
  }

  LetterTriePair * pair = this->findLetterPair(word[0]);
  if ( !pair )
    return false;

  if ( pair->trie_->removeWord(word.substr(1), removed) )
  {
    delete pair->trie_;

    std::vector<LetterTriePair>::iterator it = find(letters_.begin(), letters_.end(), *pair);

    letters_.erase(it, it+1);

    // We removed the node... maybe we have no children left?
    if ( letters_.size() == 0 && isWord_ == false )
      return true;
  }

  return false;
}

// Depth-first dump: write soFar (the text spelled so far) if this node ends a
// word, then recurse into each child in vector order. One word per line, with
// the stream flushed after each (endl).
void LexicalTrie::TrieNode::writeToStream(ostream & out, const std::string & soFar)
{
  if (isWord_)
    out << soFar << std::endl;

  for (std::vector<LetterTriePair>::size_type i = 0; i < letters_.size(); ++i)
  {
    const LetterTriePair & child = letters_[i];
    child.trie_->writeToStream(out, soFar + child.letter_);
  }
}

void LexicalTrie::TrieNode::writeToSet(std::set<std::string> & output, const std::string & soFar)
{
  if ( isWord_ )
    output.insert(soFar);

  for ( std::vector<LetterTriePair>::iterator it = letters_.begin();
      it != letters_.end(); it++ )
  {
    it->trie_->writeToSet(output, soFar + it->letter_);
  }
}

/*
 * Collect into resultSet every stored word below this node that matches
 * pattern. soFar is the text spelled by the path from the root to this node.
 *
 * Wildcards:
 *   '*'  matches zero or more characters
 *   '?'  matches zero or one character
 * Any other character must match literally. The empty word is never reported,
 * because a match requires soFar to be non-empty.
 */
void LexicalTrie::TrieNode::matchRegExp(std::set<std::string> & resultSet,
                                const std::string & pattern,
                const std::string &soFar)
{
  // Pattern exhausted: report this node if it ends a (non-empty) word.
  // This is handled explicitly so that pattern[0] and pattern.substr(1) below
  // are never evaluated on an empty pattern. substr(1) on "" throws
  // std::out_of_range.
  if ( pattern.empty() )
  {
    if ( this->isWord_ && !soFar.empty() )
      resultSet.insert(soFar);
    return;
  }

  const std::string rest = pattern.substr(1);
  std::vector<LetterTriePair>::iterator it;

  switch ( pattern[0] )
  {
    case '*':
      // Try matching 1 or more characters: consume a letter, keep the '*'
      for (it = letters_.begin(); it != letters_.end(); ++it)
        it->trie_->matchRegExp(resultSet, pattern, soFar + it->letter_);

      // Try matching 0 characters
      this->matchRegExp(resultSet, rest, soFar);
      break;

    case '?':
      // Try matching no character
      this->matchRegExp(resultSet, rest, soFar);

      // Try matching exactly one character
      for (it = letters_.begin(); it != letters_.end(); ++it)
        it->trie_->matchRegExp(resultSet, rest, soFar + it->letter_);
      break;

    default:
    {
      // Literal character: continue only if this node has that child
      LetterTriePair * pair = this->findLetterPair( pattern[0] );
      if (pair)
        pair->trie_->matchRegExp(resultSet, rest, soFar + pair->letter_);
      break;
    }
  }
}

/*
 * Recursively collect corrections for `word` below this node.
 *
 * results          - receives (suggestion, edits) entries via
 *                    LexicalTrie::AddToCorrections
 * word             - the part of the input word not yet consumed
 * maxEditDistance  - the edit budget
 * editsUsed        - edits already spent on the path from the root here
 * soFar            - the text spelled by that path
 *
 * Each step may match a letter for free, or spend one edit to delete a
 * letter of `word`, substitute a letter, or insert a letter.
 */
void LexicalTrie::TrieNode::
    suggestCorrections(std::set<LexicalTrie::Correction> & results,
               const std::string & word, int maxEditDistance,
               int editsUsed, const std::string & soFar)
{
  // Already over budget (e.g. a negative maxEditDistance): nothing can match.
  // Without this guard the "== maxEditDistance" checks below would never fire,
  // and the search would keep spending edits down the whole trie.
  if ( editsUsed > maxEditDistance )
    return;

  if ( this->isWord_ && word == "" )
  {
    LexicalTrie::AddToCorrections(results, soFar, editsUsed);
  }

  if ( word == "" && editsUsed == maxEditDistance )
    return; // can't go on...

  if ( editsUsed == maxEditDistance )
  {
    // We've used up all our changes... so we can only go on if the
    // letters match. First, see if we have the letter at this node...
    LetterTriePair * pair = this->findLetterPair( word[0] );
    if ( pair )
    {
      // We have it...
      pair->trie_->suggestCorrections(results, word.substr(1),
        maxEditDistance, editsUsed, soFar + word[0]);
    }
  }
  else
  {
    // First try removing a letter (effectively skipping one)
    if ( word != "" )
    {
      this->suggestCorrections(results, word.substr(1),
        maxEditDistance, editsUsed + 1, soFar );
    }

    // Try every child...
    for ( std::vector<LetterTriePair>::iterator it = letters_.begin();
        it != letters_.end(); it++ )
    {
      // Only do the following if the word is not empty
      if ( word != "" )
      {
        // If the letter matches, then try doing nothing and moving on
        if ( it->letter_ == word[0] )
        {
          it->trie_->suggestCorrections(results, word.substr(1),
            maxEditDistance, editsUsed, soFar + word[0] );
        }

        // Then try changing a letter - skip one letter, but try adding
        // every child letter instead
        it->trie_->suggestCorrections(results, word.substr(1),
          maxEditDistance, editsUsed + 1, soFar + it->letter_ );
      }

      // Then try adding a letter - do not skip a letter, and try adding
      // every child letter
      it->trie_->suggestCorrections(results, word,
        maxEditDistance, editsUsed + 1, soFar + it->letter_ );
    }
  }
}

/* TrieNode::findLetterPair
 * -------------------------
 * Linear search of this node's children for `letter`. Returns a pointer to
 * the matching (letter, child-trie) entry inside letters_, or NULL if there
 * is none. The pointer is invalidated by any later change to letters_.
 */
LexicalTrie::LetterTriePair * LexicalTrie::TrieNode::findLetterPair(char letter)
{
  for (std::vector<LetterTriePair>::size_type i = 0; i < letters_.size(); ++i)
  {
    if (letters_[i].letter_ == letter)
      return &letters_[i];
  }
  return NULL;
}

//----------------------- begin Lua stuff ----------------------------

// Registry name of the metatable attached to every trie userdata. It is used
// to create that metatable (createmeta in luaopen_trie) and to type-check
// handles (luaL_checkudata in Ltrie_gettrie and Ltrie_gc).
const char trie_handle[] = "trie_handle";

// Set t[name] = str (raw set, no metamethods) on the table at the top of the
// Lua stack. The table stays on top afterwards.
static void MakeTableItem (lua_State *L, const char * name, const string & str)
{
  const char * value = str.c_str ();
  lua_pushstring (L, name);    // key
  lua_pushstring (L, value);   // string value
  lua_rawset (L, -3);          // pops key and value into the table
}

// Set t[name] = n (as a Lua number, raw set) on the table at the top of the
// Lua stack. The table stays on top afterwards.
static void MakeTableItem (lua_State *L, const char * name, const int n)
{
  lua_pushstring (L, name);    // key
  lua_pushnumber (L, n);       // numeric value
  lua_rawset (L, -3);          // pops key and value into the table
}

// Fetch the LexicalTrie behind argument 1. Raises a Lua argument error if the
// argument is not a trie userdata or if its pointer has been cleared by __gc.
static LexicalTrie * Ltrie_gettrie  (lua_State *L)
{
  LexicalTrie **handle = static_cast<LexicalTrie **> (luaL_checkudata (L, 1, trie_handle));
  luaL_argcheck (L, handle != NULL && *handle != NULL, 1, "trie userdata expected");
  return *handle;
}

// trie:add (word) -- insert one word; returns nothing.
static int Ltrie_add(lua_State *L)
  {
  LexicalTrie * const trie = Ltrie_gettrie (L);   // validate arg 1 first
  trie->addWord (luaL_checkstring (L, 2));
  return 0;
  } // end of Ltrie_add

// trie:remove (word) -- delete one word if present; returns nothing.
static int Ltrie_remove(lua_State *L)
  {
  LexicalTrie * const trie = Ltrie_gettrie (L);   // validate arg 1 first
  trie->removeWord (luaL_checkstring (L, 2));
  return 0;
  } // end of Ltrie_remove

// trie:contains (word) -- returns true if the whole word is stored.
static int Ltrie_contains(lua_State *L)
  {
  LexicalTrie * const trie = Ltrie_gettrie (L);
  const bool found = trie->containsWord (luaL_checkstring (L, 2));
  lua_pushboolean (L, found);
  return 1;
  } // end of Ltrie_contains

// trie:prefix (text) -- returns true if the trie contains `text` as a prefix.
static int Ltrie_prefix(lua_State *L)
  {
  LexicalTrie * const trie = Ltrie_gettrie (L);
  const bool found = trie->containsPrefix (luaL_checkstring (L, 2));
  lua_pushboolean (L, found);
  return 1;
  } // end of Ltrie_prefix

// trie:save () -- returns an array of every stored word, in std::set order.
static int Ltrie_save(lua_State *L)
  {
  LexicalTrie *pTrie = Ltrie_gettrie (L);
  std::set<std::string> words;
  pTrie->writeToSet (words);

  lua_createtable (L, words.size (), 0);   // preallocate the array part

  int index = 0;
  for (std::set<std::string>::const_iterator w = words.begin ();
       w != words.end (); ++w)
    {
    lua_pushstring (L, w->c_str ());
    lua_rawseti (L, -2, ++index);          // 1-based array slots
    }

  return 1;    // the table
  } // end of Ltrie_save

// trie:load (tbl) -- add every value of table `tbl` as a word (keys are
// ignored). Values must be strings or numbers; returns nothing.
static int Ltrie_load(lua_State *L)
  {
  LexicalTrie *pTrie = Ltrie_gettrie (L);
  luaL_checktype (L, 2, LUA_TTABLE);

  // lua_next leaves key at -2 and value at -1; popping the value keeps the
  // key on the stack for the next step.
  for (lua_pushnil (L); lua_next (L, 2) != 0; lua_pop (L, 1))
    pTrie->addWord (luaL_checkstring (L, -1));

  return 0;
  } // end of Ltrie_load

// trie:count () -- returns the number of words stored.
static int Ltrie_count(lua_State *L)
  {
  const int count = Ltrie_gettrie (L)->numWords ();
  lua_pushinteger (L, count);
  return 1;    // the count
  } // end of Ltrie_count

// trie:match (pattern) -- returns an array of stored words matching the
// wildcard pattern ('*' = zero or more chars, '?' = zero or one char).
static int Ltrie_match(lua_State *L)
  {
  LexicalTrie *pTrie = Ltrie_gettrie (L);
  const string pattern (luaL_checkstring (L, 2));
  std::set<std::string> matches;

  pTrie->matchRegExp (pattern, matches);

  lua_createtable (L, matches.size (), 0);

  int index = 0;
  for (std::set<std::string>::const_iterator m = matches.begin ();
       m != matches.end (); ++m)
    {
    lua_pushstring (L, m->c_str ());
    lua_rawseti (L, -2, ++index);
    }

  return 1;    // the table
  } // end of Ltrie_match

/*
 * trie:corrections (word, maxDistance) -- returns an array of tables
 * { word = <suggestion>, distance = <edits> } for stored words reachable from
 * `word` within maxDistance edits. maxDistance must be non-negative.
 */
static int Ltrie_corrections(lua_State *L)
  {
  LexicalTrie *pTrie = Ltrie_gettrie (L);
  string word (luaL_checkstring (L, 2));
  int distance = (int) luaL_checkinteger (L, 3);

  // A negative edit budget is meaningless; reject it at the Lua boundary
  // instead of handing it to the recursive trie search.
  luaL_argcheck (L, distance >= 0, 3, "edit distance must not be negative");

  std::set<LexicalTrie::Correction> results;

  pTrie->suggestCorrections (word, distance, results);

  lua_createtable (L, results.size (), 0);

  int i = 1;

  // copy set into table
  for (std::set<LexicalTrie::Correction>::const_iterator it = results.begin ();
       it != results.end (); it++)
    {
    lua_createtable (L, 0, 2);
    MakeTableItem (L, "word", it->suggestedWord_);
    MakeTableItem (L, "distance", it->editDistance_);
    lua_rawseti (L, -2, i++);
    }

  return 1;    // the table
  } // end of Ltrie_corrections

// __gc metamethod: done with the trie, delete it.
// The userdata's pointer may legitimately be NULL here (already freed, or
// never filled in), so that case is skipped quietly. Raising a Lua error from
// inside the garbage collector is never useful.
static int Ltrie_gc (lua_State *L) {
  LexicalTrie **ud = (LexicalTrie **) luaL_checkudata (L, 1, trie_handle);
  if (ud != NULL && *ud != NULL)
    {
    delete *ud;
    // set userdata to NULL, so we don't try to use it now
    *ud = NULL;
    }
  return 0;
  }  // end of Ltrie_gc

// __tostring metamethod: checks that argument 1 is a live trie, then returns
// the fixed string "trie".
static int Ltrie_tostring (lua_State *L)
  {
  (void) Ltrie_gettrie (L);   // argument check only; the pointer is unused
  lua_pushstring (L, "trie");
  return 1;
  }  // end of Ltrie_tostring

//----------------------- create a new trie object ----------------------------

/*
 * trie.new ([caseSensitive]) -- create a new trie userdata.
 * caseSensitive defaults to true. It may be a number (non-zero means true, as
 * before) or a boolean.
 */
static int Ltrie_new(lua_State *L)
{
  bool caseSensitive;
  if (lua_isboolean (L, 1))
    caseSensitive = lua_toboolean (L, 1) != 0;
  else
    caseSensitive = luaL_optnumber (L, 1, 1) != 0;

  // Create the userdata first and clear its pointer before the metatable (and
  // hence __gc) is attached. The collector then never sees an uninitialised
  // pointer, and no trie is leaked if a Lua allocation here raises an error.
  LexicalTrie **ud = (LexicalTrie **)lua_newuserdata(L, sizeof (LexicalTrie *));
  *ud = NULL;
  luaL_getmetatable(L, trie_handle);
  lua_setmetatable(L, -2);

  *ud = new LexicalTrie (caseSensitive);    // store pointer to this trie in the userdata
  return 1;
  }  // end of Ltrie_new



// Methods and metamethods of trie userdata. luaopen_trie registers them into
// the trie metatable, whose __index is the metatable itself.
static const luaL_reg triemeta[] = {
  // word-level operations
  { "add",         Ltrie_add         },   // add one word
  { "remove",      Ltrie_remove      },   // remove one word
  { "contains",    Ltrie_contains    },   // is word there?
  { "prefix",      Ltrie_prefix      },   // is prefix there?
  // bulk transfer
  { "save",        Ltrie_save        },   // save trie to table
  { "load",        Ltrie_load        },   // load trie from table
  // queries
  { "count",       Ltrie_count       },   // number of words in the trie
  { "match",       Ltrie_match       },   // match wildcard pattern
  { "corrections", Ltrie_corrections },   // return table of corrections
  // metamethods
  { "__gc",        Ltrie_gc          },
  { "__tostring",  Ltrie_tostring    },
  { NULL,          NULL              }    // end marker
};


/* Open the library */

// Module-level functions registered into the "trie" table.
static const luaL_reg trielib[] = {
  { "new", Ltrie_new },   // trie.new ([caseSensitive])
  { NULL,  NULL      }    // end marker
};

// Create (or fetch) the registry metatable `name`, set its __index field to
// itself so methods resolve through it, and leave it on top of the stack.
static void createmeta(lua_State *L, const char *name)
{
  luaL_newmetatable (L, name);      // stack: mt
  lua_pushstring (L, "__index");    // stack: mt, "__index"
  lua_pushvalue (L, -2);            // stack: mt, "__index", mt
  lua_rawset (L, -3);               // mt.__index = mt; stack: mt
}

// Library entry point (require "trie"). It fills the trie metatable with
// triemeta, then registers the "trie" module table (holding trie.new) and
// returns it.
LUALIB_API int luaopen_trie(lua_State *L)
{
  // metatable for trie userdata, populated with the methods/metamethods
  createmeta (L, trie_handle);
  luaL_register (L, NULL, triemeta);
  lua_pop (L, 1);                       // drop the metatable

  // module table "trie"; luaL_register leaves it on the stack
  luaL_register (L, "trie", trielib);
  return 1;
}

