// Union type class -*- c++ -*-

#include "snprintf.h"

#ifdef __GNUC__
# pragma implementation
#endif // __GNUC__
#include "UnionType.h"
#include "UnionValue.h"
#include "ComponentList.h"
#include "Constraint.h"
#include "Printer.h"

/** @file UnionType.C
 * Union data type
 */

/* Copyright © 1998-2002 Marko Mäkelä (msmakela@tcs.hut.fi).

   This file is part of MARIA, a reachability analyzer and model checker
   for high-level Petri nets.

   MARIA is free software; you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   MARIA is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   General Public License for more details.

   The GNU General Public License is often shipped with GNU software, and
   is generally kept in a file called COPYING or LICENSE.  If you do not
   have a copy of the license, write to the Free Software Foundation,
   59 Temple Place, Suite 330, Boston, MA 02111 USA. */

/** Get the first value of the union type.
 * With a constraint, this is a copy of the constraint's first value;
 * otherwise it is component 0 holding that component's first value.
 * @return	a freshly allocated value
 */
class Value&
UnionType::getFirstValue () const
{
  if (!myConstraint) {
    class Value& firstComponentValue = (*this)[card_t (0)].getFirstValue ();
    return *new class UnionValue (*this, 0, firstComponentValue);
  }
  return *myConstraint->getFirstValue ().copy ();
}

class Value&
UnionType::getLastValue () const
{
  if (myConstraint)
    return *myConstraint->getLastValue ().copy ();
  return *new class UnionValue
    (*this, getSize () - 1, (*this)[card_t (getSize () - 1)].getLastValue ());
}

bool
UnionType::isAssignable (const class Type& type) const
{
  if (&type == this)
    return true;

  if (type.getKind () != getKind ())
    return myComponents.isAssignable (type);

  return myComponents.isAssignable
    (static_cast<const class UnionType&>(type).myComponents);
}

/** Determine whether values of a type can always be assigned to this type.
 * Only the type itself or another union type whose component list is
 * always assignable qualifies.
 * @param type	the source type
 * @return	whether every value of type is assignable
 */
bool
UnionType::isAlwaysAssignable (const class Type& type) const
{
  if (this == &type)
    return true;

  return type.getKind () == getKind () &&
    myComponents.isAlwaysAssignable
    (static_cast<const class UnionType&>(type).myComponents);
}

bool
UnionType::isConstrained (const class Value& value) const
{
  assert (value.getType ().isAssignable (*this));
  const class UnionValue& v = static_cast<const class UnionValue&>(value);
  assert (v.getIndex () < getSize ());

  return (*this)[v.getIndex ()].isConstrained (v.getValue ());
}

/** Compute the number of values of the union type and fill
 * myNumValues[i] with the cumulative count of values in components 0..i.
 * @return	the total number of values, or CARD_T_MAX if some component
 *		is unbounded or the total would not fit below CARD_T_MAX
 */
card_t
UnionType::do_getNumValues () const
{
  assert (!myNumValues && myComponents.getSize ());
  myNumValues = new card_t[myComponents.getSize ()];

  card_t total = 0;
  for (card_t c = 0; c < myComponents.getSize (); c++) {
    const card_t count = myComponents[c].getNumValues ();
    // give up as soon as the running total would reach CARD_T_MAX
    if (count == CARD_T_MAX || count >= CARD_T_MAX - total)
      return CARD_T_MAX;
    total += count;
    myNumValues[c] = total;
  }

  return total;
}

/** Map a union value to its number in the range 0..getNumValues()-1.
 * The number is the count of values in all preceding components plus
 * the number of the component value within its own type.
 * @param value	a constrained union value
 * @return	the value number
 */
card_t
UnionType::convert (const class Value& value) const
{
  assert (value.getKind () == Value::vUnion);
  assert (isConstrained (value));
  assert (getNumValues () < CARD_T_MAX);

  if (myConstraint)
    return Type::convert (value);

  const class UnionValue& unionValue =
    static_cast<const class UnionValue&>(value);
  const card_t active = unionValue.getIndex ();
  // values of components 0..active-1 come before this component
  const card_t base = active ? myNumValues[active - 1] : 0;
  const card_t number = base + (*this)[active].convert (unionValue.getValue ());

  assert (number < getNumValues ());
  return number;
}

/** Map a value number back to a union value (inverse of convert(value)).
 * @param number	a value number, less than getNumValues()
 * @return		a freshly allocated union value
 */
class Value*
UnionType::convert (card_t number) const
{
  assert (number < getNumValues ());
  assert (getNumValues () < CARD_T_MAX);

  if (myConstraint)
    return Type::convert (number);

  // locate the first component whose cumulative count exceeds number
  card_t active;
  for (active = 0; number >= myNumValues[active]; active++)
    /* empty */;

  const card_t base = active ? myNumValues[active - 1] : 0;
  class Value& componentValue = *(*this)[active].convert (number - base);

  return new class UnionValue (*this, active, componentValue);
}

class Value*
UnionType::cast (class Value& value) const
{
  for (card_t i = getSize (); i--; ) {
    if (value.getType ().isAssignable ((*this)[i])) {
      if ((*this)[i].isConstrained (value)) {
	value.setType ((*this)[i]);
	return new class UnionValue (*this, i, value);
      }
      else
	break;
    }
  }
  delete &value;
  return NULL;
}

#ifdef EXPR_COMPILE
# include "CExpression.h"
# include "util.h"
# include <stdio.h>

void
UnionType::compile (class StringBuffer& out)
{
  for (card_t i = 0; i < myComponents.getSize (); i++)
    const_cast<class Type&>(myComponents[i]).compile (out);
  Type::compile (out);
}

void
UnionType::compileDefinition (class StringBuffer& out,
			      unsigned indent) const
{
  char ixname[25];

  out.indent (indent), out.append ("struct {\n");
  out.indent (indent + 2), out.append ("unsigned t;\n");
  out.indent (indent + 2), out.append ("union {\n");
  for (card_t i = 0; i < myComponents.getSize (); i++) {
    out.indent (indent + 4), myComponents[i].appendName (out);
    snprintf (ixname, sizeof ixname, " u%u;\n", i);
    out.append (ixname);
  }
  out.indent (indent + 2), out.append ("} u;\n");
  out.indent (indent), out.append ("}");
}

/** Emit a C expression that compares two union values for (in)equality.
 * The emitted expression first compares the tags (left.t against right.t)
 * and then, for each component i from the last one down to 0, emits an
 * alternative "(left.t==i&& (<component i comparison>))", the alternatives
 * being joined with "||". A component whose compileEqual emits nothing
 * contributes the constant "1" (equality) or "0" (inequality).
 * @param out		output stream
 * @param indent	indentation level
 * @param left		left-hand-side C expression
 * @param right		right-hand-side C expression
 * @param equal		true to test for equality, false for inequality
 * @param first		if false, the first line is indented
 * @param last		if false, a trailing "&&" (equal) or "||" is appended
 * @param backslash	whether to emit a backslash before each newline
 * @return		true (an expression is always emitted)
 */
bool
UnionType::compileEqual (class StringBuffer& out,
			 unsigned indent,
			 const char* left,
			 const char* right,
			 bool equal,
			 bool first,
			 bool last,
			 bool backslash) const
{
  // l and r are left and right with room for a ".u.u<i>" member suffix
  size_t llen = strlen (left), rlen = strlen (right);
  char* l = new char[llen + 25];
  char* r = new char[rlen + 25];
  memcpy (l, left, llen);
  memcpy (r, right, rlen);

  // compare the tags first
  if (!first)
    out.indent (indent);
  out.append (left), out.append (".t");
  out.append (equal ? "==" : "!=");
  out.append (right), out.append (".t");
  out.append (backslash
	      ? (equal ? "&&\\\n" : "||\\\n")
	      : (equal ? "&&\n" : "||\n"));
  out.indent (indent++);
  // "first" is reused below to suppress indentation of the first alternative
  out.append ("("); first = true;

  // one alternative per component, from the last component to the first
  for (card_t i = myComponents.getSize (); i--; ) {
    if (first)
      first = false;
    else
      out.indent (indent);
    out.append ("(");
    out.append (left), out.append (".t");
    out.append ("==");
    out.append (i);
    out.append (backslash ? "&&\\\n" : "&&\n");
    out.indent (indent + 1);
    out.append ("(");
    snprintf (l + llen, 23, ".u.u%u", i);
    snprintf (r + rlen, 23, ".u.u%u", i);
    if (!myComponents[i].compileEqual (out, indent + 2, l, r, equal,
				       true, true, backslash))
      out.append (equal ? "1" : "0");
    out.append ("))");
    // no "||" after the alternative for component 0 (the last one emitted)
    if (i)
      out.append (backslash ? "||\\\n" : "||\n");
  }

  delete[] l; delete[] r;

  out.append (")"), indent--;
  if (!last)
    out.append (backslash
		? (equal ? "&&\\\n" : "||\\\n")
		: (equal ? "&&\n" : "||\n"));

  return true;
}

/** Emit three-way comparison code for union values.
 * The tag "<component>.t" is compared first. Then each component i is
 * compared under the extra condition "l<component>.t==i", which is
 * appended to the given condition with "&&".
 * @param out		output stream
 * @param condition	additional condition for the comparison (may be NULL)
 * @param component	C expression of the compared value
 */
void
UnionType::compileCompare3 (class StringBuffer& out,
			    const char* condition,
			    const char* component) const
{
  /* Room reserved after each copied prefix for a suffix such as ".t",
   * ".u.u<i>" or ".t==<i>" plus the terminating NUL character */
  const size_t suffixSize = 25;

  // newcomp = component followed by a member suffix
  const size_t len = strlen (component);
  char* const newcomp = new char[len + suffixSize];
  char* const offset = newcomp + len;
  memcpy (newcomp, component, len);

  // newcond = [condition "&&"] "l" component, followed by a tag test
  const size_t condlen = condition ? strlen (condition) : 0;
  const size_t prefixlen = condlen ? condlen + 3 + len : 1 + len;
  char* const newcond = new char[prefixlen + suffixSize];
  char* const coffset = newcond + prefixlen;
  if (condlen) {
    memcpy (newcond, condition, condlen);
    memcpy (newcond + condlen, "&&l", 3);
    memcpy (newcond + condlen + 3, component, len);
  }
  else
    *newcond = 'l', memcpy (newcond + 1, component, len);

  memcpy (offset, ".t", 3);
  Type::compileLeafCompare3 (out, condition, newcomp);
  for (card_t i = 0; i < myComponents.getSize (); i++) {
    // the size argument must not exceed the space left after coffset
    // (previously 26 was passed although only 25 bytes were available)
    snprintf (coffset, suffixSize, ".t==%u", i);
    snprintf (offset, suffixSize, ".u.u%u", i);
    myComponents[i].compileCompare3 (out, newcond, newcomp);
  }
  delete[] newcond;
  delete[] newcomp;
}

/** Emit C code that assigns to lvalue the successor of rvalue.
 * The emitted code has the shape
 *   do { [bool_t wrap=0;] switch ([lvalue.t=]rvalue.t) { case ...: } } while (0);
 * For each component j, the case emits the component successor. If the
 * wrap flag is not set afterwards, "continue" leaves the do-while(0).
 * Otherwise the flag is cleared, the tag is advanced to the next component
 * (cyclically: the last component is followed by component 0), and that
 * component is initialized with compileBottom.
 * NOTE(review): for the last component the emitted code also executes
 * "wrap=0" before resetting the tag to 0, so a caller-supplied wrap flag
 * is cleared even though the union value wrapped around. Confirm that
 * this is intended.
 * @param out		output stream
 * @param indent	indentation level
 * @param lvalue	C expression to be assigned
 * @param rvalue	C expression holding the current value
 * @param wrap		name of the wrap flag variable, or NULL to declare a
 *			local "bool_t wrap=0;" inside the emitted block
 */
void
UnionType::do_compileSuccessor (class StringBuffer& out,
				unsigned indent,
				const char* lvalue,
				const char* rvalue,
				const char* wrap) const
{
  // lval and rval get a ".u.u<i>" member suffix appended below
  size_t llen = strlen (lvalue), rlen = strlen (rvalue);
  char* lval = new char[llen + 25];
  char* rval = new char[rlen + 25];
  memcpy (lval, lvalue, llen);
  memcpy (rval, rvalue, rlen);

  out.indent (indent);
  out.append ("do {\n");
  if (!wrap) {
    wrap = "wrap";
    out.indent (indent + 2);
    out.append ("bool_t ");
    out.append (wrap);
    out.append ("=0;\n");
  }

  out.indent (indent += 2);
  out.append ("switch (");
  // copy the tag only when lvalue and rvalue differ
  if (lvalue != rvalue && strcmp (lvalue, rvalue))
    out.append (lvalue), out.append (".t=");
  out.append (rvalue);
  out.append (".t) {\n");

  for (card_t i = 0;; ) {
    snprintf (lval + llen, 25, ".u.u%u", i);
    snprintf (rval + rlen, 25, ".u.u%u", i);

    out.indent (indent);
    out.append ("case ");
    out.append (i);
    out.append (":\n");
    {
      // j is the current component; i becomes the next one (cyclically)
      card_t j = i++;
      if (i == myComponents.getSize ()) i = 0;
      myComponents[j].compileSuccessor (out, indent + 2, lval, rval, wrap);
    }
    out.indent (indent + 2);
    out.append ("if (!");
    out.append (wrap);
    out.append (") continue;\n");
    out.indent (indent + 2);
    out.append (wrap);
    out.append ("=0;\n");
    out.indent (indent + 2);
    out.append (lvalue);
    out.append (".t=");
    out.append (i);
    out.append (";\n");
    snprintf (lval + llen, 25, ".u.u%u", i);
    myComponents[i].compileBottom (out, indent + 2, lval);
    // the case of the last component needs no "break": it ends the switch
    if (i) {
      out.indent (indent + 2);
      out.append ("break;\n");
    }
    else
      break;
  }
  out.indent (indent);
  out.append ("}\n");
  out.indent (indent -= 2);
  out.append ("} while (0);\n");

  delete[] lval;
  delete[] rval;
}

/** Emit C code that assigns to lvalue the predecessor of rvalue.
 * The emitted code has the shape
 *   do { [bool_t wrap=0;] switch ([lvalue.t=]rvalue.t) { case ...: } } while (0);
 * For each component j, the case emits the component predecessor. If the
 * wrap flag is not set afterwards, "continue" leaves the do-while(0).
 * Otherwise the flag is cleared, the tag is set to j+1 (cyclically: the
 * last component maps to component 0), and that component is initialized
 * with compileTop.
 * NOTE(review): the tag is moved to the *next* component index here, the
 * same as in do_compileSuccessor; for a predecessor one might expect the
 * previous component. Also, "wrap=0" is emitted for the last component
 * as well. Confirm both against the value-level predecessor semantics.
 * @param out		output stream
 * @param indent	indentation level
 * @param lvalue	C expression to be assigned
 * @param rvalue	C expression holding the current value
 * @param wrap		name of the wrap flag variable, or NULL to declare a
 *			local "bool_t wrap=0;" inside the emitted block
 */
void
UnionType::do_compilePredecessor (class StringBuffer& out,
				  unsigned indent,
				  const char* lvalue,
				  const char* rvalue,
				  const char* wrap) const
{
  // lval and rval get a ".u.u<i>" member suffix appended below
  size_t llen = strlen (lvalue), rlen = strlen (rvalue);
  char* lval = new char[llen + 25];
  char* rval = new char[rlen + 25];
  memcpy (lval, lvalue, llen);
  memcpy (rval, rvalue, rlen);

  out.indent (indent);
  out.append ("do {\n");
  if (!wrap) {
    wrap = "wrap";
    out.indent (indent + 2);
    out.append ("bool_t ");
    out.append (wrap);
    out.append ("=0;\n");
  }

  out.indent (indent += 2);
  out.append ("switch (");
  // copy the tag only when lvalue and rvalue differ
  if (lvalue != rvalue && strcmp (lvalue, rvalue))
    out.append (lvalue), out.append (".t=");
  out.append (rvalue);
  out.append (".t) {\n");

  for (card_t i = 0;; ) {
    snprintf (lval + llen, 25, ".u.u%u", i);
    snprintf (rval + rlen, 25, ".u.u%u", i);

    out.indent (indent);
    out.append ("case ");
    out.append (i);
    out.append (":\n");
    {
      // j is the current component; i becomes j+1 (cyclically)
      card_t j = i++;
      if (i == myComponents.getSize ()) i = 0;
      myComponents[j].compilePredecessor (out, indent + 2, lval, rval, wrap);
    }
    out.indent (indent + 2);
    out.append ("if (!");
    out.append (wrap);
    out.append (") continue;\n");
    out.indent (indent + 2);
    out.append (wrap);
    out.append ("=0;\n");
    out.indent (indent + 2);
    out.append (lvalue);
    out.append (".t=");
    out.append (i);
    out.append (";\n");
    snprintf (lval + llen, 25, ".u.u%u", i);
    myComponents[i].compileTop (out, indent + 2, lval);
    // the case of the last component needs no "break": it ends the switch
    if (i) {
      out.indent (indent + 2);
      out.append ("break;\n");
    }
    else
      break;
  }
  out.indent (indent);
  out.append ("}\n");
  out.indent (indent -= 2);
  out.append ("} while (0);\n");

  delete[] lval;
  delete[] rval;
}

void
UnionType::compileCast (class CExpression& cexpr,
			unsigned indent,
			const class Type& target,
			const char* lvalue,
			const char* rvalue) const
{
  assert (isAssignable (target));
  if (target.getKind () == Type::tUnion)
    static_cast<const class UnionType&>(target).compileCastFrom
      (cexpr, indent, *this, lvalue, rvalue);
  else {
    class StringBuffer& out = cexpr.getOut ();
    class StringBuffer cond;
    card_t first = CARD_T_MAX;
    for (card_t i = getSize (); i--; ) {
      if (myComponents[i].isAssignable (target)) {
	if (first == CARD_T_MAX)
	  first = i;
	else
	  cond.append ("&&");
	cond.append (rvalue);
	cond.append (".t!=");
	cond.append (i);
      }
    }
    assert (first != CARD_T_MAX);
    out.indent (indent);
    out.append ("if (");
    out.append (cond);
    out.append (")\n");
    cexpr.compileError (indent + 2, errUnion);
    size_t len = strlen (rvalue);
    char* rval = new char[len + 25];
    memcpy (rval, rvalue, len);
    snprintf (rval + len, 25, ".u.u%u", first);
    myComponents[first].compileCast (cexpr, indent, target, lvalue, rval);
    delete[] rval;
    if (const class Constraint* c = target.getConstraint ())
      c->compileCheck (cexpr, indent, lvalue);
  }
}

/** Emit C code that converts a value of another type to this union type.
 * The same type is copied directly. Otherwise the last component (by
 * index) that the source is assignable to becomes the active component.
 * @param cexpr		the compilation
 * @param indent	indentation level
 * @param source	source type
 * @param lvalue	C expression to be assigned
 * @param rvalue	C expression holding the source value
 */
void
UnionType::compileCastFrom (class CExpression& cexpr,
			    unsigned indent,
			    const class Type& source,
			    const char* lvalue,
			    const char* rvalue) const
{
  class StringBuffer& out = cexpr.getOut ();
  if (&source == this) {
    out.indent (indent);
    out.append (lvalue);
    out.append ("=");
    out.append (rvalue);
    out.append (";\n");
    return;
  }

  card_t i = getSize ();
  while (i--) {
    const class Type& component = (*this)[i];
    if (!source.isAssignable (component))
      continue;

    // select the component, then cast into its member
    out.indent (indent);
    out.append (lvalue);
    out.append (".t=");
    out.append (i);
    out.append (";\n");

    const size_t prefixLength = strlen (lvalue);
    char* const member = new char[prefixLength + 25];
    memcpy (member, lvalue, prefixLength);
    snprintf (member + prefixLength, 25, ".u.u%u", i);
    source.compileCast (cexpr, indent, component, member, rvalue);
    delete[] member;

    if (const class Constraint* c = getConstraint ())
      c->compileCheck (cexpr, indent, lvalue);
    return;
  }
  assert (false);
}

/** Emit C code that converts a union value to its value number.
 * Component 0 is handled by the "default:" label. Every other component i
 * first stores (or adds) the cumulative count of the preceding components
 * and then adds the number of the component value.
 * @param out		output stream
 * @param indent	indentation level
 * @param value		C expression holding the value
 * @param number	C variable receiving the number
 * @param add		whether to add to number instead of assigning it
 */
void
UnionType::do_compileConversion (class StringBuffer& out,
				 unsigned indent,
				 const char* value,
				 const char* number,
				 bool add) const
{
  assert (getNumValues () < CARD_T_MAX && getSize ());

  const size_t prefixLength = strlen (value);
  char* const member = new char[prefixLength + 25];
  char* const suffix = member + prefixLength;
  memcpy (member, value, prefixLength);

  out.indent (indent);
  out.append ("switch (");
  out.append (value);
  out.append (".t) {\n");

  for (card_t i = 0; i < getSize (); i++) {
    out.indent (indent);
    if (!i)
      out.append ("default:\n");
    else {
      out.append ("case ");
      out.append (i);
      out.append (":\n");
      out.indent (indent + 2);
      out.append (number);
      out.append (add ? "+=" : "=");
      out.append (myNumValues[i - 1]);
      out.append (";\n");
    }

    snprintf (suffix, 25, ".u.u%u", i);
    (*this)[i].compileConversion (out, indent + 2, member, number, i || add);

    out.indent (indent + 2);
    out.append ("break;\n");
  }

  out.indent (indent);
  out.append ("}\n");
  delete[] member;
}

void
UnionType::compileReverseConversion (class StringBuffer& out,
				     unsigned indent,
				     const char* number,
				     const char* value) const
{
  if (myConstraint)
    Type::compileReverseConversion (out, indent, number, value);
  else {
    size_t length = strlen (value);
    char* val = new char[length + 25];
    char* offset = val + length;
    memcpy (val, value, length);

    bool next = false;

    for (card_t i = getSize (); i--; next = true) {
      out.indent (indent);
      if (next)
	out.append ("else ");
      card_t num = i ? myNumValues[i - 1] : 0;
      if (num) {
	out.append ("if ("), out.append (number); out.append (">=");
	out.append (num), out.append (") ");
      }
      out.append ("{\n");

      out.indent (indent + 2);
      if (num) {
	out.append (number);
	out.append ("-=");
	out.append (num);
	out.append (", ");
      }
      out.append (value), out.append (".t=");
      snprintf (offset, 25, ".u.u%u", i);
      out.append (offset + 4), out.append (";\n");

      (*this)[i].compileReverseConversion (out, indent + 2, number, val);

      out.indent (indent);
      out.append ("}\n");
    }

    delete[] val;
  }
}

/** Emit C code that encodes a union value.
 * Types with a bounded number of values use the generic encoder.
 * Otherwise the tag is encoded (when there are several components),
 * followed by the active component.
 * @param cexpr		the compilation
 * @param indent	indentation level
 * @param func		name of the encoding function
 * @param value		C expression holding the value
 */
void
UnionType::compileEncoder (class CExpression& cexpr,
			   unsigned indent,
			   const char* func,
			   const char* value) const
{
  if (getNumValues () < CARD_T_MAX) {
    Type::compileEncoder (cexpr, indent, func, value);
    return;
  }

  const size_t prefixLength = strlen (value);
  char* const member = new char[prefixLength + 25];
  char* const suffix = member + prefixLength;
  memcpy (member, value, prefixLength);

  class StringBuffer& out = cexpr.getOut ();
  if (getSize () > 1) {
    out.indent (indent);
    out.append (func);
    out.append (" (");
    out.append (value);
    out.append (".t, ");
    out.append (log2 (getSize ()));
    out.append (");\n");
  }

  out.indent (indent);
  out.append ("switch (");
  out.append (value);
  out.append (".t) {\n");

  card_t i = getSize ();
  while (i--) {
    out.indent (indent);
    out.append ("case ");
    out.append (i);
    out.append (":\n");
    snprintf (suffix, 25, ".u.u%u", i);
    (*this)[i].compileEncoder (cexpr, indent + 2, func, member);
    out.indent (indent + 2);
    out.append ("break;\n");
  }

  out.indent (indent);
  out.append ("}\n");

  delete[] member;
}

/** Emit C code that decodes a union value.
 * Types with a bounded number of values use the generic decoder.
 * Otherwise the tag is decoded (or set to 0 when there is only one
 * component), followed by the active component.
 * @param cexpr		the compilation
 * @param indent	indentation level
 * @param func		name of the decoding function
 * @param value		C expression to be assigned
 */
void
UnionType::compileDecoder (class CExpression& cexpr,
			   unsigned indent,
			   const char* func,
			   const char* value) const
{
  if (getNumValues () < CARD_T_MAX)
    Type::compileDecoder (cexpr, indent, func, value);
  else {
    // val is value followed by a ".u.u<i>" member suffix
    size_t length = strlen (value);
    char* val = new char[length + 25];
    char* offset = val + length;
    memcpy (val, value, length);

    class StringBuffer& out = cexpr.getOut ();
    if (getSize () > 1) {
      out.indent (indent);
      out.append (value);
      out.append (".t = ");
      out.append (func);
      out.append (" (");
      out.append (log2 (getSize ()));
      out.append (");\n");
    }
    else {
      out.indent (indent);
      out.append (value);
      out.append (".t = 0;\n");
    }

    out.indent (indent);
    out.append ("switch ("), out.append (value), out.append (".t) {\n");

    for (card_t i = getSize (); i--; ) {
      out.indent (indent);
      out.append ("case "), out.append (i), out.append (":\n");
      snprintf (offset, 25, ".u.u%u", i);
      (*this)[i].compileDecoder (cexpr, indent + 2, func, val);
      out.indent (indent + 2);
      out.append ("break;\n");
    }

    out.indent (indent);
    out.append ("}\n");

    // release the scratch buffer (it was previously leaked)
    delete[] val;
  }
}

#endif // EXPR_COMPILE

/** Print the type as "union {components}", followed by the constraint
 * if there is one.
 * @param printer	the printer object
 */
void
UnionType::display (const class Printer& printer) const
{
  printer.printRaw ("union ");
  printer.delimiter ('{')++;
  myComponents.display (printer);
  --printer.delimiter ('}');

  if (!myConstraint)
    return;
  myConstraint->display (printer);
}
