
require "numru/ganalysis/planet"
require "numru/ganalysis/met"
require "numru/ganalysis/sigma_coord"

module NumRu
  module GAnalysis

    # Meteorological analysis regarding vertical sections, integration, etc.
    module MetZ
      module_function

      # Derive the mass stream function in the pressure coordinate
      # 
      # Applicable both to pressure- and sigma-coordinate input data
      # (the output is always on the pressure coordinate).
      # 
      # ARGUMENTS
      # * v [GPhys] : meridional wind with a vertical dimension (p or sigma)
      #   It must have a latitudinal dimension too. Longitudinal and time
      #   dimensions are optional. If it has a longitudinal dimension,
      #   zonal mean is taken. The order of the dimensions is not restricted.
      #   If v contains missing data (NArrayMiss), its validity mask is
      #   handed to the vertical integration routine.
      # * ps [GPhys] : surface pressure. It must have the same grid
      #   as v but for the vertical dimension (ps.rank must be v.rank-1)
      # * pcoord [1D VArray](optional) : output vertical coordinate. If nil,
      #   v's pressure coordinate is used (pressure input), or sigma*1000 hPa
      #   is used (sigma input).
      # * vs [nil(default) or GPhys]: vs is not needed (neglected)
      #   when v has a sigma coordinate. It is an optional parameter
      #   to specify the surface values of v, when it is in the pressure
      #   coordinate. vs can be omitted (nil), even when v has a pressure
      #   coordinate; in that case, vs is set by interpolating v if ps is 
      #   within the p range of v (e.g. when ps<=1000hPa), or it is naively 
      #   extended (using the bottom values of v) if ps is out of the range 
      #   (e.g. when ps>1000hPa). In other words, the current implementation
      #   assumes that v is available below the surface, as is customary
      #   for reanalysis data.
      #
      # RETURN VALUE
      # * a GPhys named "mstrm" whose units are kg.m-1 times v's units.
      #   The longitudinal dimension (if any) is removed by the zonal mean,
      #   and the vertical dimension is replaced with pcoord (put in the
      #   increasing order).
      def mass_strm_p(v, ps, pcoord=nil, vs=nil)

        pascal = Units["Pa"]
        grav = Met.g.to_f

        #< consolidate the p or sigma coordinate input >

        if zdim = Met.find_prs_d(v)   # substitution, not comparison
          # has a pressure coordinate
          pcv = v.coord(zdim) # pcv is v's p coord, not pcoord from outside.
                              # This is used only to feed c_cap_by_boundary.
          pcoord = pcv.copy if pcoord.nil?  # if not given from outside, use pcv

          pcv_val = pcv.val
          v_val = v.val             # should be NArray or NArrayMiss
          if v_val.is_a?(NArrayMiss)
            # The mask must be taken before to_na; the data themselves are
            # passed to the C routines with a fill value.
            mask = v_val.get_mask
            misval = 9.9692099683868690e+36
            v_val = v_val.to_na(misval)
          else
            mask = nil
            misval = nil
          end
          if pcv_val[0] > pcv_val[-1]
            # reverse the p coordinate to the increasing order
            rev = [true]*zdim + [-1..0,false]
            pcv_val = pcv_val[-1..0]
            v_val = v_val[ *rev ]
            mask = mask[ *rev ] if mask   # keep the mask aligned with v_val
          end

          pcv_val = pcv.units.convert2(pcv_val, pascal) if pcv.units!=pascal
          pcv_over_g = pcv_val / grav

          ps_val = ps.val
          ps_val = ps_val.to_na if ps_val.is_a?(NArrayMiss)
          ps_val = ps.units.convert2(ps_val, pascal) if ps.units!=pascal
          ps_over_g = ps_val / grav

          vs_val = vs && vs.val   # nil (default) or vs.val (if vs is given)
          vs_val = vs_val.to_na if vs_val.is_a?(NArrayMiss)

          v_val, p_over_g, nzbound = GPhys.c_cap_by_boundary(v_val, zdim, 
                                    pcv_over_g, true, ps_over_g, vs_val, misval)

          if mask
            # zdim gets longer by one (the level at ps), which is valid;
            # extend the mask in the same way as in integrate_w_p_to_ps.
            mask_e = NArray.byte(*v_val.shape).fill!(1)
            mask_e[ *([true]*zdim + [0..-2,false]) ] = mask
            mask = mask_e
          end

        elsif zdim = SigmaCoord.find_sigma_d(v)  # substitution, not comparison
          # has a sigma coordinate
          sig = v.coord(zdim)
          unless pcoord
            pcoord = sig * 1000
            pcoord.units = "hPa"
            pcoord.name = "p"
            pcoord.long_name = "pressure"
            pcoord.put_att("standard_name","air_pressure")
            pcoord.put_att("positive","down")
          end
          nzbound = nil
          ps = ps.convert_units(pascal) if ps.units != pascal
          sig_val = sig.val
          v_val = v.val    # normally NArray, not NArrayMiss (coz sigma)
          if v_val.is_a?(NArrayMiss)
            mask = v_val.get_mask   # must be taken BEFORE converting to NArray
            v_val = v_val.to_na
          else
            mask = nil
          end
          p_over_g = SigmaCoord.sig_ps2p(ps.val/grav, sig_val, zdim)
        else
          raise ArgumentError, "v does not have a p or sigma coordinate."
        end
        
        #< cumulative vertical integration >
 
        pc_val = pcoord.val
        if pc_val[0] > pc_val[-1]
          # change it to the increasing order
          pc_val = pc_val[-1..0]
          pcoord = pcoord.copy.replace_val(pc_val)
        end
        pc_val = pcoord.units.convert2(pc_val,pascal)

        pc_over_g = pc_val / grav

        rho_v_cum = GPhys.c_cum_integ_irreg(v_val, mask, p_over_g, zdim, nzbound,
                                         pc_over_g, nil)

        #< zonal mean & latitudinal factor >

        _lam, phi, lond, latd = Planet.get_lambda_phi(v, false)

        if latd.nil?
          raise(ArgumentError, "v appears not having a latitudinal dimension")
        end
        if lond
          rho_v_cum = rho_v_cum.mean(lond)
          latd -= 1 if lond<latd
        end

        # 2*pi*a*cos(phi), shaped to broadcast along the latitude dimension
        a_cos = NMath.cos(phi.val) * ( 2 * Math::PI * Planet.radius.to_f )
        latd.times{a_cos.newdim!(0)}
        (rho_v_cum.rank - latd -1).times{a_cos.newdim!(-1)}

        mstrm_val = rho_v_cum * a_cos

        #< make a GPhys >

        axes = Array.new
        for d in 0...v.rank
          case d
          when lond
            # lost by zonal mean
          when zdim
            pax = Axis.new().set_pos(pcoord)
            axes.push(pax)
          else
            axes.push(v.axis(d).copy)   # kept
          end
        end
        grid = Grid.new( *axes )

        units = Units["kg.m-1"]    # p/g*a : Pa / (m.s-2) * m = kg.m-1
        units *= v.units
        mstrm_va = VArray.new(mstrm_val, {"long_name"=>"mass stream function",
                                "units"=>units.to_s}, "mstrm")
        mstrm = GPhys.new(grid, mstrm_va)
        mstrm
      end

      # mass stream function on any vertical coordinate
      #
      # Similar to mass_strm_p, but it supports representation to have
      # an arbitrary physical quantity, such as potential temperature,
      # as the vertical coordinate (instead of pressure).
      # 
      # Applicable both to pressure- and sigma-coordinate input data
      # 
      # ARGUMENTS
      # * v [GPhys] : meridional wind with a vertical dimension (p or sigma)
      #   It must have a latitudinal dimension too. Longitudinal and time
      #   dimensions are optional. If it has a longitudinal dimension,
      #   zonal mean is taken. The order of the dimensions is not restricted.
      #   If v contains missing data (NArrayMiss), its validity mask is
      #   handed to the vertical integration routine.
      # * ps [GPhys] : surface pressure. It must have the same grid
      #   as v but for the vertical dimension (ps.rank must be v.rank-1)
      # * w [GPhys] : Grid-point values (at the same points as v) of the
      #   quantity used to represent the vertical coordinate.
      #   Its shape must be the same as that of v, as a matter of course.
      # * wcoord [1D VArray] : Output vertical coordinate. Its units must be
      #   compatible with (convertible to) those of w.
      # * vs [nil(default) or GPhys]: vs is not needed (neglected)
      #   when v has a sigma coordinate. It is an optional parameter
      #   to specify the surface values of v, when it is in the pressure
      #   coordinate. vs can be omitted (nil), even when v has a pressure
      #   coordinate; in that case, vs is set by interpolating v if ps is 
      #   within the p range of v (e.g. when ps<=1000hPa), or it is naively 
      #   extended (using the bottom values of v) if ps is out of the range 
      #   (e.g. when ps>1000hPa). In other words, the current implementation
      #   assumes that v is available below the surface, as is customary
      #   for reanalysis data.
      # * ws [nil(default) or GPhys]: same as vs but for the surface value of w.
      #
      # RETURN VALUE
      # * a GPhys named "mstrm" whose units are kg.m-1 times v's units,
      #   with the vertical dimension replaced by wcoord (increasing order).
      # 
      def mass_strm_any(v, ps, w, wcoord, vs=nil, ws=nil)

        pascal = Units["Pa"]
        grav = Met.g.to_f

        #< check >

        raise(ArgumentError,"v.shape != w.shape")  if v.shape != w.shape
        raise(ArgumentError,"ps.rank != v.rank-1")  if ps.rank != v.rank-1
        raise(ArgumentError,"w.units !~wcoord.units") if w.units !~ wcoord.units

        #< prepare data >

        if zdim = Met.find_prs_d(v)   # substitution, not comparison
          # has a pressure coordinate
          pcv = v.coord(zdim)   # v's p coord
          pcv_val = pcv.val
          v_val = v.val             # should be NArray or NArrayMiss
          if v_val.is_a?(NArrayMiss)
            mask = v_val.get_mask   # must be taken BEFORE converting to NArray
            misval = 9.9692099683868690e+36
            v_val = v_val.to_na(misval)
          else
            mask = nil
            misval = nil
          end
          w_val = w.val             # should be NArray or NArrayMiss
          w_val = w_val.to_na if w_val.is_a?(NArrayMiss)
          if pcv_val[0] > pcv_val[-1]
            # reverse the p coordinate to the increasing order
            rev = [true]*zdim + [-1..0,false]
            pcv_val = pcv_val[-1..0]
            v_val = v_val[ *rev ]
            w_val = w_val[ *rev ]
            mask = mask[ *rev ] if mask   # keep the mask aligned with v_val
          end

          pcv_val = pcv.units.convert2(pcv_val, pascal) if pcv.units!=pascal
          pcv_over_g = pcv_val / grav

          ps_val = ps.val
          ps_val = ps_val.to_na if ps_val.is_a?(NArrayMiss)
          ps_val = ps.units.convert2(ps_val, pascal) if ps.units!=pascal
          ps_over_g = ps_val / grav

          vs_val = vs && vs.val   # nil (default) or vs.val (if vs is given)
          vs_val = vs_val.to_na if vs_val.is_a?(NArrayMiss)

          ws_val = ws && ws.val   # nil (default) or ws.val (if ws is given)
          ws_val = ws_val.to_na if ws_val.is_a?(NArrayMiss)

          # Same argument list as the calls in mass_strm_p/integrate_w_p_to_ps.
          v_val, p_over_g, nzbound = GPhys.c_cap_by_boundary(v_val, zdim, 
                                    pcv_over_g, true, ps_over_g, vs_val, misval)

          w_val, p_over_g, nzbound = GPhys.c_cap_by_boundary(w_val, zdim, 
                                    pcv_over_g, true, ps_over_g, ws_val, nil)

          if mask
            # zdim gets longer by one (the level at ps), which is valid
            mask_e = NArray.byte(*v_val.shape).fill!(1)
            mask_e[ *([true]*zdim + [0..-2,false]) ] = mask
            mask = mask_e
          end

        elsif zdim = SigmaCoord.find_sigma_d(v)  # substitution, not comparison
          # has a sigma coordinate
          sig = v.coord(zdim)
          nzbound = nil
          ps = ps.convert_units(pascal) if ps.units != pascal
          sig_val = sig.val
          v_val = v.val    # normally NArray, not NArrayMiss (coz sigma)
          if v_val.is_a?(NArrayMiss)
            mask = v_val.get_mask   # must be taken BEFORE converting to NArray
            v_val = v_val.to_na
          else
            mask = nil
          end
          w_val = w.val
          w_val = w_val.to_na if w_val.is_a?(NArrayMiss)
          p_over_g = SigmaCoord.sig_ps2p(ps.val/grav, sig_val, zdim)
        else
          raise ArgumentError, "v does not have a p or sigma coordinate."
        end

        #< cumulative vertical integration >
 
        wc_val = wcoord.val
        if wc_val[0] > wc_val[-1]
          # change it to the increasing order
          wc_val = wc_val[-1..0]
          wcoord = wcoord.copy.replace_val(wc_val)
        end
        # Only compatibility of units is required above, so express the
        # target levels in w's units before they are compared with w.
        wc_val = wcoord.units.convert2(wc_val, w.units) if wcoord.units != w.units

        rho_v_cum = GPhys.c_cum_integ_irreg(v_val, mask, p_over_g, zdim, nzbound,
                                         wc_val, w_val)

        #< zonal mean & latitudinal factor >

        _lam, phi, lond, latd = Planet.get_lambda_phi(v, false)

        if latd.nil?
          raise(ArgumentError, "v appears not having a latitudinal dimension")
        end
        if lond
          rho_v_cum = rho_v_cum.mean(lond)
          latd -= 1 if lond<latd
        end

        # 2*pi*a*cos(phi), shaped to broadcast along the latitude dimension
        a_cos = NMath.cos(phi.val) * ( 2 * Math::PI * Planet.radius.to_f )
        latd.times{a_cos.newdim!(0)}
        (rho_v_cum.rank - latd -1).times{a_cos.newdim!(-1)}

        mstrm_val = rho_v_cum * a_cos

        #< make a GPhys >

        axes = Array.new
        for d in 0...v.rank
          case d
          when lond
            # lost by zonal mean
          when zdim
            wax = Axis.new().set_pos(wcoord)
            axes.push(wax)
          else
            axes.push(v.axis(d).copy)   # kept
          end
        end
        grid = Grid.new( *axes )

        units = Units["kg.m-1"]    # p/g*a : Pa / (m.s-2) * m = kg.m-1
        units *= v.units
        mstrm_va = VArray.new(mstrm_val, {"long_name"=>"mass stream function",
                                "units"=>units.to_s}, "mstrm")
        mstrm = GPhys.new(grid, mstrm_va)
        mstrm
      end

      # Integrate v with p to ps (where v==vs if vs is given)
      #
      # Normally, p and ps are pressure, but they are actually arbitrary.
      # The assumption here is that the ps is the upper cap of p, and
      # the integration with p is from the smallest p up to ps.
      # fact is a factor. E.g., 1/gravity to get mass-weighted integration
      # through dp/g = -\rho dz
      #
      # ARGUMENTS
      # * v [GPhys] a multi-dimensional GPhys
      # * pdim [Integer or String] The dimension of p
      # * ps [GPhys] the capping value of p (~surface pressure);
      #   rank must be one smaller than v's (no-missing data allowed).
      #   If nil, the integration is up to the maximum of p.
      # * vs [nil or GPhys] v at ps (shape must be ps's); if nil, v is 
      #   interpolated. (If vs==nil, no extrapolation is made when ps>p.max)
      # * fact [nil or UNumeric or..] factor to be multiplied afterwards
      #   (e.g., 1/Met.g)
      # * name [nil or String] name to be set
      # * long_name [nil or String] long_name to be set
      #
      # RETURN VALUE
      # * a GPhys without the p dimension (units: v's units times p's units,
      #   before multiplication by fact)
      def integrate_w_p_to_ps(v, pdim, ps, vs: nil, fact: nil,
                              name: nil, long_name: nil)
        pdim = v.dim_index(pdim)
        p = v.coord(pdim)

        if ps
          if p.units !~ ps.units
            raise("units mis-match #{p.units} vs #{ps.units}")
          end
          if p.units != ps.units
            p = p.convert_units(ps.units)
          end
          pv = p.val
          psv = ps.val
          if psv.is_a?(NArrayMiss)
            raise("data missing exists in ps") if psv.count_invalid != 0
            psv = psv.to_na
          end
        else
          pv = p.val
          psv = NArray.float(1).fill!(pv.max)
        end

        vv = v.val
        if vv.is_a?(NArrayMiss)
          mask = vv.get_mask
          misval = 9.9692099683868690e+36  # near 15 * 2^119 (as nc fill val)
          vv = vv.to_na(misval)
        else
          mask = nil
          misval = nil
        end

        nzbound = nil

        if ps
          if pv[0] > pv[-1]
            # reverse the p coordinate to the increasing order
            rev = [true]*pdim + [-1..0,false]
            pv = pv[-1..0]
            vv = vv[ *rev ]
            mask = mask[ *rev ] if mask   # keep the mask aligned with vv
          end
          if vs
            vsv = vs.val
            if vsv.is_a?(NArrayMiss)
              raise("data missing exists in vs") if vsv.count_invalid != 0
              vsv = vsv.to_na
            end
          else
            vsv = nil
          end
          vv, pv, nzbound = GPhys.c_cap_by_boundary(vv, pdim,
                                           pv, true, psv, vsv, misval)
          if mask
            # pdim gets longer by one (the level at ps), which is valid.
            # Without missing data the mask simply stays nil.
            mask_e = NArray.byte(*vv.shape).fill!(1)
            mask_e[ *( [true]*pdim + [0..-2,false]) ] = mask
            mask = mask_e
          end
          # Non-destructive newdim: do not modify the array obtained from
          # ps.val in place.
          psv = psv.newdim(pdim)
        end

        vi = GPhys.c_cell_integ_irreg(vv, mask, pv, pdim, nzbound, psv, nil)
        osh = v.shape.clone
        osh.delete_at(pdim)
        vi.reshape!(*osh)

        data = VArray.new(vi, v.data, v.name)
        data.units = v.units * p.units
        data = data * fact if fact
        data.name = name if name
        data.long_name = long_name if long_name

        grid = v.grid.delete_axes(pdim)
        grid.set_lost_axes( Array.new )   # reset
        GPhys.new(grid, data)
      end

    end
  end
end

######################################

if __FILE__ == $0
  # Demo: integrate a field of ones in p from the smallest pressure level
  # up to the surface pressure -- first letting v be interpolated to ps,
  # then specifying the surface value of v explicitly as zero.
  require "numru/ggraph"
  include NumRu

  temp = GPhys::IO.open("../../../testdata/T.jan.nc","T")
  ones = temp.copy.replace_val( NArray.float(*temp.shape).fill!(1.0) )
  psfc = GPhys::IO.open("../../../testdata/pres.jan.nc","pres").convert_units("Pa")
  zdim = GAnalysis::Met.find_prs_d(ones)

  integ = GAnalysis::MetZ.integrate_w_p_to_ps(ones, zdim, psfc)
  p "ps", psfc.val
  p "result 1", integ.val + 1000

  zero_sfc = psfc * 0.0
  integ = GAnalysis::MetZ.integrate_w_p_to_ps(ones, zdim, psfc, vs: zero_sfc)
  p "result 2", integ.val + 1000
end
