#!/usr/bin/ruby
# coding: utf-8
require "numru/gphys"
require 'getoptlong'
#require "./gp_dcpam_methods_v1.0"
include NumRu

# Return the command-line help text for this script.
#
# The caller is responsible for printing it (e.g. `puts print_usage`).
#
# @return [String] multi-line usage message
def print_usage
  <<~USAGE
    Usage :
      $ ruby vint.rb Ps.nc IN.nc OUT.nc (options)

        Ps.nc  : NetCDF file for surface pressure
        IN.nc  : Input NetCDF file whose variable will be vertically integrated
        OUT.nc : Output NetCDF file

      options:
        --merge, -m
          Distributed files are used as input. Those files are merged
          and output is one file.
          Note that the names of NetCDF files to be merged are IN_rank??????.nc
          and Ps_rank??????.nc, if this option is given.
        --varname <variable name>
          [optional] Variable name is given. By default, the variable name
          is taken from the base name of IN.nc (without ".nc").
        --time_index_s <index>, --tis <index>
          [optional] Index of the first time step to process (default: 0).
        --time_index_e <index>, --tie <index>
          [optional] Index of the last time step to process; negative values
          count from the end (default: -1, i.e. the last time step).
        --grav <value>
          [optional] Gravitational acceleration in m s-2 (default: 9.8).
  USAGE
end

# ---------------------------------------------------------------------------
# Command-line options.
# Parsed values are stored in the $OPT_* globals read by the rest of the
# script; `grav` is the gravitational acceleration as a UNumeric in m s-2.
# ---------------------------------------------------------------------------
parser = GetoptLong.new

parser.set_options(
  ['--merge', '-m',              GetoptLong::NO_ARGUMENT],
  ['--varname',                  GetoptLong::OPTIONAL_ARGUMENT],
  ['--time_index_s', '--tis',    GetoptLong::REQUIRED_ARGUMENT],
  ['--time_index_e', '--tie',    GetoptLong::REQUIRED_ARGUMENT],
  ['--grav',                     GetoptLong::REQUIRED_ARGUMENT],
#  ['--plev', '-p',               GetoptLong::REQUIRED_ARGUMENT],
)

# Defaults (used when the corresponding option is not given).
$OPT_merge = false
$OPT_varname = ""
$OPT_time_index_s = 0
$OPT_time_index_e = -1
$OPT_grav = 9.8
begin
  parser.each_option do |name, arg|
    # Assign each option explicitly instead of eval-ing generated Ruby code:
    # with eval, an argument containing a quote broke (or injected into) the
    # generated source, and the resulting SyntaxError escaped the rescue.
    # GetoptLong yields the canonical (first-listed) name, so aliases such as
    # --tis arrive here as --time_index_s.
    case name
    when '--merge'
      $OPT_merge = true
    when '--varname'
      $OPT_varname = arg
    when '--time_index_s'
      # Integer()/Float() reject malformed numbers instead of silently
      # turning them into 0 as to_i/to_f did. Base 10 keeps "08" valid.
      $OPT_time_index_s = Integer(arg, 10)
    when '--time_index_e'
      $OPT_time_index_e = Integer(arg, 10)
    when '--grav'
      $OPT_grav = Float(arg)
    end
  end
rescue GetoptLong::Error
  # GetoptLong has already printed a diagnostic to stderr.
  exit(1)
rescue ArgumentError => e
  warn "ERROR : Invalid option value: #{e.message}"
  exit(1)
end
grav = UNumeric[ $OPT_grav, 'm s-2' ]

#print $OPT_merge, "\n"
#print $OPT_varname
#print a_plev, "\n"
#exit


# The three positional arguments (Ps file, input file, output file) are
# mandatory; show the help text and stop otherwise.
if ARGV.size < 3
  puts print_usage
  exit
end


# Planetary surface pressure
ncfn_ps = ARGV[0]
vname_ps = "Ps"
# Input
ncfn_in = ARGV[1]
# Output
ncfn_out = ARGV[2]


# Ask before clobbering an existing output file. Only an explicit "yes"
# continues. EOF on stdin (gets returns nil, e.g. in non-interactive runs)
# is treated as "no" instead of crashing on nil.chomp.
if File.exist?(ncfn_out)
  print "File, ", ncfn_out, " exists.\n"
  print "Overwrite the file? (yes/no)\n"
  input = $stdin.gets
  if input.nil? || input.chomp != 'yes'
    print "STOP\n"
    exit
  end
end

ncfn = ncfn_in
outncfn = ncfn_out

# Name of the variable to integrate: use --varname when a non-empty name was
# given (one-character names such as "U" or "T" are valid; the old `size > 1`
# test silently ignored them). Otherwise derive it from the input file's base
# name, e.g. "dir/U.nc" -> "U".
if $OPT_varname.empty?
  unless ncfn.end_with?('.nc')
    print "ERROR : Unexpected extention of file name: ", ncfn, "\n"
    exit
  end
  vname = File.basename(ncfn, '.nc')
else
  vname = $OPT_varname
end


# Echo the resolved settings before any data is read.
[
  ["   Input (Ps)    : ", ncfn_ps],
  ["   Input         : ", ncfn],
  ["   Variable name : ", vname],
  ["   Grav. accel.  : ", grav],
  ["   Output        : ", outncfn],
].each { |label, value| print label, value, "\n" }

# The "sigm" coordinate is read from the surface-pressure file. With --merge,
# only the rank-0 piece is opened, presumably because every rank carries the
# same sigm values (NOTE(review): confirm).
sigm_url = ($OPT_merge ? ncfn_ps[0..-4] + "_rank000000.nc" : ncfn_ps) + "@sigm"
gp_sigm = GPhys::IO.open_gturl(sigm_url)
na_sigm = gp_sigm.val
# Layer thickness in sigma: difference of each pair of neighbouring sigm values.
sigm_k   = na_sigm[0..-2]
sigm_kp1 = na_sigm[1..-1]
na_delsig = sigm_k - sigm_kp1
kmax = na_delsig.size
# Shape (1, 1, kmax, 1): the size-1 axes let it broadcast against the input
# data, whose 3rd of 4 axes is assumed to be the 'sig' level axis (TODO confirm).
na_delsig = na_delsig.reshape!(1, 1, kmax, 1)
#
# Surface pressure: open it (all rank files when --merge is given) and keep
# only the time range selected by --time_index_s / --time_index_e.
ps_url =
  if $OPT_merge
    "#{ncfn_ps[0..-4]}_rank??????.nc@#{vname_ps}"
  else
    "#{ncfn_ps}@#{vname_ps}"
  end
gp_ps = GPhys::IO.open_gturl(ps_url)
ps_time_values = gp_ps.coord('time').val
first_time = ps_time_values[$OPT_time_index_s]
last_time  = ps_time_values[$OPT_time_index_e]
gp_ps = gp_ps.cut('time' => first_time..last_time)
#
# Open the variable to be integrated (all rank files when --merge is given)
# and keep only the selected time range, as was done for Ps.
if $OPT_merge
  url = ncfn[0..-4] + "_rank??????.nc@" + vname
else
  url = ncfn + "@" + vname
end
gp = GPhys::IO.open_gturl( url )
#
na_time = gp.coord('time').val
times = na_time[$OPT_time_index_s]
timee = na_time[$OPT_time_index_e]
gp = gp.cut('time'=>times..timee)

# Number of time steps that will be processed (used for the progress display).
# Negative indices count from the end, matching the na_time lookups above.
# Both ends are normalized; previously a negative start index was used as-is,
# which inflated ntime.
nt_all = na_time.size
itimes = $OPT_time_index_s < 0 ? $OPT_time_index_s + nt_all : $OPT_time_index_s
itimee = $OPT_time_index_e < 0 ? $OPT_time_index_e + nt_all : $OPT_time_index_e
ntime = (itimee-itimes+1)

# Vertically integrate the variable and write the result to the output file:
#   vint = sum over 'sig' of (var * delsig) * Ps / grav
# each_along_dims_write loops over the last dimension (-1) of gp. The itime
# counter below assumes that dimension is time (TODO confirm).
outfile = NetCDF.create(outncfn)
begin
  itime = 0
  GPhys::NetCDF_IO.each_along_dims_write(gp, outfile, -1) do |sub|
    # Progress display, based on
    # https://qiita.com/hokaccha/items/3abd55aa23894b57ffd1#comment-00bb23380f8b3b7cc489
    progress = ((itime+1).to_f/ntime.to_f*100).round(3)
    print "working... "+progress.to_s+"%\r"
    $stdout.flush
    print "\n" if (itime+1) == ntime
    # gp_ps is indexed with three subscripts, the last selecting step itime.
    gp_vint = ( sub * na_delsig ).sum('sig') * gp_ps[true,true,itime..itime] / grav
    itime += 1
    [gp_vint]
  end
ensure
  # Close the output even if the integration raises part-way,
  # so the file handle is not leaked.
  outfile.close
end
