class Func
  # Writes the opening of the generated Fortran module `gms_math_derivative`:
  # the `module` line, its two `use` statements and `implicit none`.
  # Returns nil.
  def header
    ["module gms_math_derivative\n",
     "  use datatype\n",
     "  use mem_manager\n",
     "  implicit none\n"].each { |line| print line }
    nil
  end

  # Writes the Fortran `contains` line; the driver calls this between the
  # interface declarations and the generated function bodies.
  def contains
    print("contains\n")
  end

  # Writes the line that closes the `gms_math_derivative` module.
  def footer
    print("end module gms_math_derivative\n")
  end
end

class Func1
  # Opens the generic Fortran interface `d_<name>`.
  # name: axis letter, "x", "y" or "z".
  def header1(name)
    print "  interface d_#{name}\n"
  end

  # Lists `module procedure d_<name>_<type>` inside the open interface, but
  # only when the axis letter `name` occurs in the variable-type tag `type`
  # (e.g. "y" in "xy"); prints nothing otherwise.
  def header2(name, type)
    return if type.index(name).nil?

    print "    module procedure d_#{name}_#{type}\n"
  end

  # Closes the interface block and leaves one blank line after it.
  def header3
    print "  end interface\n\n"
  end

  # Emits the Fortran function `d_<name>_<type>(input) result(output)`.
  #
  # The generated code takes a new id via `get_new_id_<type>`. It toggles the
  # grid flag of axis `name` with `mod(flag + 1, 2)` and copies the other
  # two flags. It then stores the neighbour difference
  # (work(j+1) - work(j)) / d<name> along that axis, written at index
  # j + output%grid(axis).
  #
  # Nothing is printed when `type` does not contain `name`. An earlier
  # version emitted a `stop` stub for that case; it is intentionally not
  # generated. If `name` contains a letter of `type` but is not exactly
  # "x"/"y"/"z", only the function header and `end function` are written.
  def output(name, type)
    return if type.index(name).nil?

    print "  function d_#{name}_#{type}(input) result(output)\n"
    print "    type(var_#{type}), intent(in) :: input\n"
    print "    type(var_#{type}) :: output\n"
    print "    integer :: new_id\n\n"
    print "    call get_new_id_#{type}(new_id)\n\n"
    print "    output%id = new_id\n\n"

    axis = %w[x y z].index(name)
    emit_forward_difference(name, type, axis + 1) unless axis.nil?

    print "  end function d_#{name}_#{type}\n\n"
  end

  private

  # Writes the grid-flag updates and the difference assignment for axis
  # number `dim` (1, 2 or 3).
  def emit_forward_difference(name, type, dim)
    (1..3).each do |c|
      rhs = c == dim ? "mod( input%grid(#{c}) + 1, 2 )" : "input%grid(#{c})"
      # The last flag line is followed by a blank line.
      print "    output%grid(#{c}) = #{rhs}\n#{c == 3 ? "\n" : ''}"
    end

    lb = "lb_axis#{dim}"
    ub = "ub_axis#{dim}"
    target = subscript(dim, "#{lb}+output%grid(#{dim}) : #{ub}-1+output%grid(#{dim})", 'new_id')
    upper  = subscript(dim, "#{lb}+1:#{ub}", 'input%id')
    lower  = subscript(dim, "#{lb}:#{ub}-1", 'input%id')

    print "    work_#{type}(#{target}) & \n"
    print "         = (   work_#{type}(#{upper}) &\n"
    print "             - work_#{type}(#{lower}) ) / d#{name}\n"
  end

  # Builds a 4-index subscript. `range` goes in spatial position `dim`, `:`
  # in the other two spatial positions, and `last` (`new_id` or `input%id`)
  # in the trailing position.
  def subscript(dim, range, last)
    (1..3).map { |c| c == dim ? range : ':' }.push(last).join(',')
  end
end


class Func2
  # Opens the generic Fortran interface `d_2<name>`.
  # name: axis letter, "x", "y" or "z".
  def header1(name)
    print "  interface d_2#{name}\n"
  end

  # Lists `module procedure d_2<name>_<type>` inside the open interface when
  # the axis letter `name` occurs in the type tag `type`; prints nothing
  # otherwise.
  def header2(name, type)
    return unless type.include?(name)

    print "    module procedure d_2#{name}_#{type}\n"
  end

  # Closes the interface block and leaves one blank line after it.
  def header3
    print "  end interface\n\n"
  end

  # Emits the Fortran function `d_2<name>_<type>(input) result(output)`.
  #
  # The generated code takes a new id via `get_new_id_<type>` and copies the
  # input's grid flags unchanged. It then stores the central difference
  # (work(i+1) - work(i-1)) / (2.0D0 * d<name>) along axis `name` for interior
  # indices lb+1 .. ub-1.
  #
  # All three axes now produce identically formatted code. Previously y/z
  # differed from x in the grid copy and in the spacing of the denominator.
  # Nothing is printed when `type` does not contain `name`.
  def output(name, type)
    return unless type.include?(name)

    print "  function d_2#{name}_#{type}(input) result(output)\n"
    print "    type(var_#{type}), intent(in) :: input\n"
    print "    type(var_#{type}) :: output\n"
    print "    integer :: new_id\n\n"
    print "    call get_new_id_#{type}(new_id)\n\n"
    print "    output%id = new_id\n\n"

    axis = %w[x y z].index(name)
    # NOTE(review): a `name` that is not exactly "x"/"y"/"z" still yields a
    # function with no body, as before.
    emit_central_difference(name, type, axis + 1) if axis

    print "  end function d_2#{name}_#{type}\n\n"
  end

  private

  # Writes the grid copy and the difference assignment along axis number
  # `dim` (1, 2 or 3).
  def emit_central_difference(name, type, dim)
    half = 1 # stencil reaches `half` points on each side of i
    lb = "lb_axis#{dim}"
    ub = "ub_axis#{dim}"
    target = subscript(dim, "#{lb}+#{half} : #{ub}-#{half}", 'new_id')
    upper  = subscript(dim, "#{lb}+#{2 * half}:#{ub}", 'input%id')
    lower  = subscript(dim, "#{lb}:#{ub}-#{2 * half}", 'input%id')

    # A symmetric difference leaves the grid flags unchanged.
    print "    output%grid = input%grid\n\n"
    print "    work_#{type}(#{target}) & \n"
    print "         = (   work_#{type}(#{upper}) &\n"
    print "             - work_#{type}(#{lower}) ) / ( #{2 * half}.0D0 * d#{name} )\n"
  end

  # Builds a 4-index subscript. `range` goes in spatial position `dim`, `:`
  # in the other two spatial positions, and `last` (`new_id` or `input%id`)
  # in the trailing position.
  def subscript(dim, range, last)
    (1..3).map { |c| c == dim ? range : ':' }.push(last).join(',')
  end
end

class Func4
  # Opens the generic Fortran interface `d_4<name>`.
  # name: axis letter, "x", "y" or "z".
  def header1(name)
    print "  interface d_4#{name}\n"
  end

  # Lists `module procedure d_4<name>_<type>` inside the open interface when
  # the axis letter `name` occurs in the type tag `type`; prints nothing
  # otherwise.
  def header2(name, type)
    return unless type.include?(name)

    print "    module procedure d_4#{name}_#{type}\n"
  end

  # Closes the interface block and leaves one blank line after it.
  def header3
    print "  end interface\n\n"
  end

  # Emits the Fortran function `d_4<name>_<type>(input) result(output)`.
  #
  # The generated code takes a new id via `get_new_id_<type>` and copies the
  # input's grid flags unchanged. It then stores the wide central difference
  # (work(i+2) - work(i-2)) / (4.0D0 * d<name>) along axis `name` for interior
  # indices lb+2 .. ub-2.
  #
  # All three axes now produce identically formatted code. Previously y/z
  # differed from x in the grid copy and in the spacing of the denominator.
  # Nothing is printed when `type` does not contain `name`.
  def output(name, type)
    return unless type.include?(name)

    print "  function d_4#{name}_#{type}(input) result(output)\n"
    print "    type(var_#{type}), intent(in) :: input\n"
    print "    type(var_#{type}) :: output\n"
    print "    integer :: new_id\n\n"
    print "    call get_new_id_#{type}(new_id)\n\n"
    print "    output%id = new_id\n\n"

    axis = %w[x y z].index(name)
    # NOTE(review): a `name` that is not exactly "x"/"y"/"z" still yields a
    # function with no body, as before.
    emit_central_difference(name, type, axis + 1) if axis

    print "  end function d_4#{name}_#{type}\n\n"
  end

  private

  # Writes the grid copy and the difference assignment along axis number
  # `dim` (1, 2 or 3).
  def emit_central_difference(name, type, dim)
    half = 2 # stencil reaches `half` points on each side of i
    lb = "lb_axis#{dim}"
    ub = "ub_axis#{dim}"
    target = subscript(dim, "#{lb}+#{half} : #{ub}-#{half}", 'new_id')
    upper  = subscript(dim, "#{lb}+#{2 * half}:#{ub}", 'input%id')
    lower  = subscript(dim, "#{lb}:#{ub}-#{2 * half}", 'input%id')

    # A symmetric difference leaves the grid flags unchanged.
    print "    output%grid = input%grid\n\n"
    print "    work_#{type}(#{target}) & \n"
    print "         = (   work_#{type}(#{upper}) &\n"
    print "             - work_#{type}(#{lower}) ) / ( #{2 * half}.0D0 * d#{name} )\n"
  end

  # Builds a 4-index subscript. `range` goes in spatial position `dim`, `:`
  # in the other two spatial positions, and `last` (`new_id` or `input%id`)
  # in the trailing position.
  def subscript(dim, range, last)
    (1..3).map { |c| c == dim ? range : ':' }.push(last).join(',')
  end
end

# Driver: writes the complete gms_math_derivative Fortran module to stdout.
#
# No input files are read. The generator is fully determined by the two
# lists below.

# Axis letters; each gets its own generic interface per operator family.
axes = %w[x y z]
# Variable-type tags. The generator classes emit an operator for an
# (axis, tag) pair only when the tag contains the axis letter.
types = %w[x y z xy xz yz xyz]

func = Func.new
func1 = Func1.new
func2 = Func2.new
func4 = Func4.new
# Operator families, in the order their interfaces and bodies are emitted.
generators = [func1, func2, func4]

func.header

# Interface section: for each family, one generic interface per axis listing
# its applicable specific procedures.
generators.each do |gen|
  axes.each do |axis|
    gen.header1(axis)
    types.each { |tag| gen.header2(axis, tag) }
    gen.header3
  end
end

func.contains

# Body section: for each (axis, tag) pair, the bodies of all three families.
axes.each do |axis|
  types.each do |tag|
    generators.each { |gen| gen.output(axis, tag) }
  end
end

func.footer
