#!/usr/bin/ruby
# Prints (via Kernel#print, i.e. to $stdout) the Fortran source of a module
# `gms_math_<opename>` that overloads `operator(<ope>)` for the var_* field
# types (x, y, z, xy, xz, yz, xyz) and for real(8) scalars.
#
# Typical use: `init` once, then `header`, then `output` for every operand
# pair, then `footer`.
class Func
  # Field type names; each letter names an axis the field spans.
  FIELD_TYPES = %w[x y z xy xz yz xyz].freeze

  # [axis letter, Fortran loop index], ordered by grid dimension 1..3.
  AXES = [%w[x ix], %w[y iy], %w[z iz]].freeze

  # [axis letter, "do" line, matching "end do" line], outermost loop first.
  # Each "end do" carries the same indentation as its "do".
  LOOPS = [
    ["z", "    do iz = lb_axis3, ub_axis3\n", "    end do\n"],
    ["y", "      do iy = lb_axis2, ub_axis2\n", "      end do\n"],
    ["x", "        do ix = lb_axis1, ub_axis1\n", "        end do\n"]
  ].freeze

  # Indentation of the continuation lines of the generated assignment.
  CONTINUATION = " " * 22

  # Sets the operator to generate.
  # @param ope [String] Fortran operator symbol, e.g. "**"
  # @param opename [String] name used in module/procedure names, e.g. "power"
  def init(ope, opename)
    @ope = ope
    @opename = opename
  end

  # Prints the module preamble and the generic interface that lists every
  # procedure `output` can generate, followed by `contains`.
  def header
    print "module gms_math_#{@opename}\n"
    print "  use datatype\n"
    print "  use mem_manager\n"
    print "  implicit none\n"
    print "  integer :: ix, iy, iz\n\n"
    print "  interface operator(#{@ope})\n"
    FIELD_TYPES.each { |lhs| print_procedures(FIELD_TYPES.map { |rhs| "#{lhs}_#{rhs}" }) }
    print_procedures(FIELD_TYPES.map { |rhs| "real_#{rhs}" })
    print_procedures(FIELD_TYPES.map { |lhs| "#{lhs}_real" })
    print_procedures(FIELD_TYPES) if unary?
    print "  end interface\n"
    print "contains\n"
  end

  # Prints one Fortran function implementing `in1 <ope> in2`.
  #
  # @param in1 [String] left operand: a field type, "real", or "" to request
  #   the unary form (generated only for plus/minus)
  # @param in2 [String] right operand: a field type or "real"
  # Prints nothing for real/real, for a unary real, or for an empty left
  # operand when the operator is not plus/minus.
  def output(in1, in2)
    return if in1 == "real" && in2 == "real"

    if in1 == ""
      print_unary(in2) if unary? && in2 != "real"
      return
    end

    print_binary(in1, in2)
  end

  # Prints the closing line of the module.
  def footer
    print "end module gms_math_#{@opename}\n"
  end

  private

  # True when the operator also has a unary form (+a / -a).
  def unary?
    %w[plus minus].include?(@opename)
  end

  # Prints one `module procedure` line listing <opename>_<suffix> names.
  def print_procedures(suffixes)
    names = suffixes.map { |suffix| "#{@opename}_#{suffix}" }
    print "  module procedure #{names.join(', ')}\n"
  end

  # Prints the unary function for a field type. The result gets its own
  # storage slot (like the binary functions) so the operand is never modified.
  def print_unary(type)
    print "  function #{@opename}_#{type}(input) result(output)\n"
    print "    type(var_#{type}), intent(in) :: input \n"
    print "    type(var_#{type})             :: output \n\n"
    print "    integer :: new_id\n\n"
    print "    call get_new_id_#{type}(new_id)\n"
    print "    output%id = new_id \n"
    print "    output%grid = input%grid\n"
    print "    work_#{type}(:,:,:,new_id) = #{@ope} work_#{type}(:,:,:,input%id)\n\n"
    print "  end function #{@opename}_#{type}\n"
  end

  # Prints the binary function for two operands (at most one of them "real").
  def print_binary(in1, in2)
    idx1 = axis_indices(in1)
    idx2 = axis_indices(in2)
    combined = in1 + in2
    # The result spans every axis spanned by either operand.
    out = AXES.map(&:first).select { |axis| combined.include?(axis) }.join

    print "  function #{@opename}_#{in1}_#{in2}(input1, input2) result(output)\n"
    print declaration(in1, "input1")
    print declaration(in2, "input2")
    print "    type(var_#{out})             :: output \n\n"
    print "    integer :: new_id\n"
    print_grid_checks(idx1, idx2)
    print "\n"

    print "    call get_new_id_#{out}(new_id)\n"
    print "    output%id = new_id \n"
    print "\n"

    print_loop_nest(out, in1, idx1, in2, idx2)
    print "\n"

    print_output_grid(idx1, idx2)
    print "  end function #{@opename}_#{in1}_#{in2}\n\n"
  end

  # Array index expressions (one per axis) used to address an operand:
  # the loop variable for spanned axes, "1" otherwise, "" for real scalars.
  def axis_indices(type)
    return Array.new(AXES.size, "") if type.include?("real")

    AXES.map { |axis, var| type.include?(axis) ? var : "1" }
  end

  # Declaration line for one dummy argument.
  def declaration(type, name)
    return "    real(8), intent(in) :: #{name} \n" if type == "real"

    "    type(var_#{type}), intent(in) :: #{name} \n"
  end

  # Operands that share an axis must use the same grid along it.
  def print_grid_checks(idx1, idx2)
    AXES.each_with_index do |(_axis, var), dim|
      next unless idx1[dim] == var && idx2[dim] == var

      n = dim + 1
      print "    if ( input1%grid(#{n}) /= input2%grid(#{n}) ) stop \"(#{@ope})grid violation\"\n"
    end
  end

  # Prints the nested do-loops over the result's axes and the element-wise
  # assignment; loops are closed innermost-first with matching indentation.
  def print_loop_nest(out, in1, idx1, in2, idx2)
    loops = LOOPS.select { |axis, *| out.include?(axis) }
    loops.each { |_axis, do_line, _end_line| print do_line }
    print "\n"

    out_idx = AXES.map { |axis, var| out.include?(axis) ? var : "1" }
    print "        work_#{out}(#{out_idx.join(', ')}, new_id) & \n"
    print assignment_rhs(in1, idx1, in2, idx2)

    loops.reverse_each { |_axis, _do_line, end_line| print end_line }
  end

  # Right-hand side of the element-wise assignment, including the leading "=".
  def assignment_rhs(in1, idx1, in2, idx2)
    left = operand(in1, idx1, "input1")
    right = operand(in2, idx2, "input2")
    if in1 != "real" && in2 != "real"
      "#{CONTINUATION}= #{left} & \n#{CONTINUATION}#{@ope} #{right}\n"
    else
      "#{CONTINUATION}= #{left} #{@ope} #{right}\n"
    end
  end

  # Fortran expression for one operand: the scalar itself, or a work_ element.
  def operand(type, idx, name)
    type == "real" ? name : "work_#{type}(#{idx.join(', ')}, #{name}%id)"
  end

  # The result inherits each axis' grid from the first operand spanning it;
  # axes spanned by neither operand get -1.
  def print_output_grid(idx1, idx2)
    AXES.each_with_index do |(_axis, var), dim|
      n = dim + 1
      line =
        if idx1[dim] == var then "    output%grid(#{n}) = input1%grid(#{n})\n"
        elsif idx2[dim] == var then "    output%grid(#{n}) = input2%grid(#{n})\n"
        else "    output%grid(#{n}) = -1\n"
        end
      print line
    end
  end
end

# Operand kinds: the seven field types plus the "real" scalar. The left-hand
# side also includes "" because Func#output treats an empty left operand as a
# request for the unary form of the operator.
rhs_types = %w[x y z xy xz yz xyz real]
lhs_types = rhs_types + [""]

generator = Func.new
generator.init("**", "power")
generator.header
lhs_types.each do |lhs|
  rhs_types.each { |rhs| generator.output(lhs, rhs) }
end
generator.footer
