# Prints the fixed framing lines of the generated Fortran module
# (opening statements, the `contains` separator and the closing statement).
# Every method writes to standard output via `print` and returns nil.
class Func
  # Name used in both the `module` and the `end module` statements.
  MODULE_NAME = "gms_math_bar".freeze

  # Writes the module opening: the `module` statement, the two `use`
  # statements and `implicit none`.
  def header
    print "module #{MODULE_NAME}\n",
          "  use datatype\n",
          "  use mem_manager\n",
          "  implicit none\n"
  end

  # Writes the `contains` line that precedes the function definitions.
  def contains
    print "contains\n"
  end

  # Writes the statement that closes the module.
  def footer
    print "end module #{MODULE_NAME}\n"
  end
end

# Prints the generic interfaces `bar_<name>` and the specific functions
# `bar_<name>_<type>` of the generated Fortran module.
#
# All methods write to standard output via `print` and return nil.
# Whitespace inside the string literals (including blanks before a line
# end) is part of the emitted text.
class Func1
  # Opens the generic interface for one name.
  #
  # name - label appended to `bar_`; the documented values are "x", "y", "z".
  def header1(name)
    print "  interface bar_#{name}\n"
  end

  # Adds one `module procedure` line to the currently open interface.
  # Nothing is printed when +type+ does not contain +name+.
  #
  # name - label of the interface ("x", "y" or "z").
  # type - suffix of the specific procedure, e.g. "xy".
  def header2(name, type)
    return unless type.include?(name)

    print "    module procedure bar_#{name}_#{type}\n"
  end

  # Closes the interface block and leaves one blank line after it.
  def header3
    print "  end interface\n\n"
  end

  # Prints the complete function `bar_<name>_<type>`.
  #
  # The generated function requests a new id through `get_new_id_<type>`,
  # toggles the grid flag of the axis selected by +name+ with
  # `mod(flag + 1, 2)`, copies the other two flags, and stores in
  # `work_<type>` the mean of neighbouring input values along that axis.
  #
  # Nothing is printed when +type+ does not contain +name+. For a +name+
  # other than "x", "y" or "z" that is contained in +type+, only the
  # declarations and the `end function` line are printed.
  #
  # name - axis label ("x", "y" or "z").
  # type - suffix used in `var_<type>`, `work_<type>` and the function name.
  def output(name, type)
    return unless type.include?(name)

    routine = "bar_#{name}_#{type}"
    work = "work_#{type}"

    print "  function #{routine}(input) result(output)\n"
    print "    type(var_#{type}), intent(in) :: input\n"
    print "    type(var_#{type}) :: output\n"
    print "    integer :: new_id\n\n"
    print "    call get_new_id_#{type}(new_id)\n\n"
    print "    output%id = new_id\n\n"

    case name
    when "x"
      print "    output%grid(1) = mod( input%grid(1) + 1, 2 )\n",
            "    output%grid(2) = input%grid(2)\n",
            "    output%grid(3) = input%grid(3)\n\n",
            "    #{work}(lb_axis1+output%grid(1) : ub_axis1-1+output%grid(1),:,:,new_id) & \n",
            "         = (   #{work}(lb_axis1+1:ub_axis1,:,:,input%id) &\n",
            "             + #{work}(lb_axis1:ub_axis1-1,:,:,input%id) ) / 2.0D0 \n"
    when "y"
      print "    output%grid(1) = input%grid(1)\n",
            "    output%grid(2) = mod( input%grid(2) + 1, 2 )\n",
            "    output%grid(3) = input%grid(3)\n\n",
            "    #{work}(:,lb_axis2+output%grid(2) : ub_axis2-1+output%grid(2),:,new_id) & \n",
            "         = (   #{work}(:,lb_axis2+1:ub_axis2,:,input%id) &\n",
            "             + #{work}(:,lb_axis2:ub_axis2-1,:,input%id) ) / 2.0D0 \n"
    when "z"
      print "    output%grid(1) = input%grid(1)\n",
            "    output%grid(2) = input%grid(2)\n",
            "    output%grid(3) = mod( input%grid(3) + 1, 2 )\n\n",
            "    #{work}(:,:,lb_axis3+output%grid(3) : ub_axis3-1+output%grid(3),new_id) & \n",
            "         = (   #{work}(:,:,lb_axis3+1:ub_axis3,input%id) &\n",
            "             + #{work}(:,:,lb_axis3:ub_axis3-1,input%id) ) / 2.0D0\n"
    end

    print "  end function #{routine}\n\n"
  end
end

#list1 = open("func_list1", "r")
#list2 = open("func_list2", "r")

# Driver: prints the whole generated module to standard output in the order
# header, generic interfaces, `contains`, function bodies, footer.
func = Func.new
func1 = Func1.new

# Suffixes used in the generated identifiers (bar_<name>_<type>, var_<type>,
# work_<type>).
types = %w[x y z xy xz yz xyz]
# One generic interface bar_<name> is generated per entry.
names = %w[x y z]

func.header

# Interface section: Func1#header2 itself skips every type that does not
# contain the name, so all combinations are simply offered to it.
names.each do |name|
  func1.header1(name)
  types.each { |type| func1.header2(name, type) }
  func1.header3
end

func.contains

# Implementation section: same iteration order as the interface section;
# Func1#output applies the same name-in-type filter.
names.each do |name|
  types.each { |type| func1.output(name, type) }
end

func.footer