#!/usr/local/bin/python3.11

import os
import sys
import argparse


# numpy is required; tell the user which package is actually missing.
try:
    import numpy as np
except ImportError:
    # Only a failed import means "not installed"; the message names numpy
    # (the original text mistakenly asked the user to install scipy).
    print('qmc-fit error: numpy is not present on your machine.\n  Please install numpy and retry.')
#end try

try:
    from scipy.optimize import fmin
except ImportError:
    print('qmc-fit error: scipy is not present on your machine.\n  Please install scipy and retry.')
#end try

try:
    import matplotlib.pyplot as plt
    params = {'legend.fontsize':14,'figure.facecolor':'white','figure.subplot.hspace':0.,
          'axes.labelsize':16,'xtick.labelsize':14,'ytick.labelsize':14}
    plt.rcParams.update(params)
    plots_available = True
except (ImportError,RuntimeError):
    plots_available = False
#end try


def find_nexus_modules():
    """Make the Nexus library directory shipped next to this script importable.

    The directory is resolved as ``<this file>/../../lib`` (normalised to an
    absolute path) and appended to ``sys.path``.

    Raises:
        AssertionError: if the resolved directory does not exist.
    """
    relative_lib = os.path.join(__file__, '..', '..', 'lib')
    lib_dir = os.path.abspath(relative_lib)
    # Callers rely on a failure here to fall back to installed modules.
    assert os.path.exists(lib_dir)
    sys.path.append(lib_dir)
#end def find_nexus_modules


def import_nexus_module(module_name):
    """Import a module by name and return the module object.

    Args:
        module_name: dotted module name, resolved through the normal import
            machinery (and therefore through the current ``sys.path``).

    Returns:
        The imported module.

    Raises:
        ImportError: if the module cannot be found or imported.
    """
    # Local import kept so the function stays self-contained; the unused
    # ``inspect`` import from the original has been dropped.
    import importlib
    return importlib.import_module(module_name)
#end def import_nexus_module

# Load Nexus modules
#   Both branches below must bind exactly the same set of names, because the
#   fit-function table and the fit drivers further down reference all of them.
try:
    # Attempt specialized path-based imports.
    #  (The executable should still work even if Nexus is not installed)
    find_nexus_modules()
    versions = import_nexus_module('versions')
    nexus_version = versions.nexus_version
    del versions

    generic = import_nexus_module('generic')
    obj   = generic.obj
    log   = generic.log
    warn  = generic.warn
    error = generic.error
    del generic

    developer = import_nexus_module('developer')
    DevBase = developer.DevBase
    del developer

    numerics = import_nexus_module('numerics')
    jackknife            = numerics.jackknife
    jackknife_aux        = numerics.jackknife_aux
    simstats             = numerics.simstats
    equilibration_length = numerics.equilibration_length
    curve_fit            = numerics.curve_fit
    least_squares        = numerics.least_squares
    # renamed on import: this script defines its own eos_fit() driver
    num_eos_fit          = numerics.eos_fit
    morse_fit            = numerics.morse_fit
    morse                = numerics.morse
    morse_einf           = numerics.morse_Einf
    morse_re             = numerics.morse_re
    morse_De             = numerics.morse_De
    morse_a              = numerics.morse_a
    morse_width          = numerics.morse_width
    morse_depth          = numerics.morse_depth
    morse_Ee             = numerics.morse_Ee
    morse_k              = numerics.morse_k
    murnaghan            = numerics.murnaghan
    birch                = numerics.birch
    vinet                = numerics.vinet
    murnaghan_pressure   = numerics.murnaghan_pressure
    birch_pressure       = numerics.birch_pressure
    vinet_pressure       = numerics.vinet_pressure
    del numerics
except Exception:
    # Failing path-based imports, import installed Nexus modules.
    #  (Exception rather than a bare except, so Ctrl-C / SystemExit are not
    #   swallowed and silently turned into a second import attempt.)
    from versions import nexus_version
    from generic import obj,log,warn,error
    from developer import DevBase
    from numerics import jackknife,jackknife_aux
    from numerics import simstats,equilibration_length
    from numerics import curve_fit,least_squares
    # The names below were missing from this branch originally, which made
    # the fit-function table fail with NameError whenever it was taken.
    from numerics import eos_fit as num_eos_fit
    from numerics import morse_fit,morse
    from numerics import morse_Einf as morse_einf
    from numerics import morse_re,morse_De,morse_a
    from numerics import morse_width,morse_depth,morse_Ee,morse_k
    from numerics import murnaghan,birch,vinet
    from numerics import murnaghan_pressure,birch_pressure,vinet_pressure
#end try


# Polynomial functions
#   Coefficients p are stored lowest order first (p[0] + p[1]*x + ...), as in
#   the polynomial fit functions below; numpy's poly* routines expect highest
#   order first, hence the p[::-1] reversal in each helper.
def roots_n(p, x):
    """Return the roots of the x-th derivative of the polynomial p."""
    derivative = np.polyder(p[::-1], x)
    return np.roots(derivative)
#end def roots_n

def root_vals_n(p, x):
    """Return the polynomial p evaluated at the roots of its x-th derivative."""
    locations = roots_n(p, x)
    return np.polyval(p[::-1], locations)
#end def root_vals_n

def moment_n(p, x):
    """Return the x-th derivative of p evaluated at the roots of its (x-1)-th derivative."""
    derivative = np.polyder(p[::-1], x)
    locations = roots_n(p, x-1)
    return np.polyval(derivative, locations)
#end def moment_n

# Table of every supported fit function, grouped by fit type:
#   ts  - timestep fits
#   u   - single-parameter (Hubbard U / EXX) fits
#   eos - Morse and equation-of-state fits
# Each entry provides:
#   nparam   - number of fitted parameters
#   function - callable taking (p, x): the parameter sequence and the abscissa
#   format   - text used when printing the fitted formula; it is passed
#              through str.format with one '(mean +/- error)' string per
#              fitted parameter
#   params   - list of (name, callable) pairs for quantities derived from the
#              fitted parameters
all_fit_functions = obj(
    # timestep fits: the only derived quantity is the zero-timestep intercept p[0]
    ts = obj(
        linear = obj(
            nparam   = 2,
            function = lambda p,t: p[0]+p[1]*t,
            format   = '{0} + {1}*t',
            params   = [('intercept',lambda p: p[0])],
            ),
        quadratic = obj(
            nparam   = 3,
            function = lambda p,t: p[0]+p[1]*t+p[2]*t*t,
            format   = '{0} + {1}*t + {2}*t^2',
            params   = [('intercept',lambda p: p[0])],
            ),
        sqrt = obj(
            nparam   = 3,
            function = lambda p,t: p[0]+p[1]*np.sqrt(t)+p[2]*t,
            format   = '{0} + {1}*sqrt(t) + {2}*t',
            params   = [('intercept',lambda p: p[0])],
            ),
        ),
    # polynomial fits: derived quantities come from the stationary points
    # (roots of the first derivative) of the fitted polynomial
    u = obj(
        quadratic = obj(
            nparam   = 3,
            function = lambda p,t: p[0]+p[1]*t+p[2]*t*t,
            format   = '{0} + {1}*t + {2}*t^2',
            params   = [('minimum_x',lambda p: roots_n(p, 1)),
                        ('minimum_e',lambda p: root_vals_n(p,1)),
                        ('curvature',lambda p: moment_n(p, 2))],
            ),
        # cubic/quartic: keep only stationary points with positive second
        # derivative (moment_n(p,2) > 0), i.e. the minima
        cubic = obj(
            nparam   = 4,
            function = lambda p,t: p[0]+p[1]*t+p[2]*t*t+p[3]*t*t*t,
            format   = '{0} + {1}*t + {2}*t^2 + {3}*t^3',
            params   = [('minimum_x',lambda p: roots_n(p, 1)[moment_n(p,2) > 0]),
                        ('minimum_e',lambda p: root_vals_n(p,1)[moment_n(p,2) > 0]),
                        ('curvature',lambda p: moment_n(p, 2)[moment_n(p,2) > 0])],
            ),            
        quartic = obj(
            nparam   = 5,
            function = lambda p,t: p[0]+p[1]*t+p[2]*t*t+p[3]*t*t*t+p[4]*t*t*t*t,
            format   = '{0} + {1}*t + {2}*t^2 + {3}*t^3 + {4}*t^4',
            params   = [('minimum_x',lambda p: roots_n(p, 1)[moment_n(p,2) > 0]),
                        ('minimum_e',lambda p: root_vals_n(p,1)[moment_n(p,2) > 0]),
                        ('curvature',lambda p: moment_n(p, 2)[moment_n(p,2) > 0])],
            ),            
        ),
    # Morse / equation-of-state fits: functions come from the numerics module.
    # The format strings here contain no {n} placeholders, so the formula is
    # printed symbolically, without the fitted values substituted in.
    eos = obj(
        morse = obj(
            nparam   = 4,
            function = morse,
            format   = 'De ( (1-e^-a(r-re))^2 - 1 ) + E_infinity',
            params   = [('minimum_x', morse_re), 
                        ('a',         morse_a),
                        ('de',        morse_De),
                        ('einf',      morse_einf),
                        ('well_width',morse_width),
                        ('well_depth',morse_depth),
                        ('minimum_e' ,morse_Ee),
                        ('morse_k'   ,morse_k)]
            ),
        # birch/vinet/murnaghan: derived values are read directly from the
        # parameter vector as p[0]=e_inf, p[1]=minimum_x, p[2]=B, p[3]=Bp.
        # NOTE(review): the 'pressure' entries take a second argument V, unlike
        # every other derived-quantity callable - confirm the numerics fit
        # routines supply it when evaluating auxfuncs.
        birch = obj(
            nparam   = 4,
            format   = 'E_inf + 9*V_0*B/16*((V_0/V)**(2./3)-1)**2*( 2 + (Bp-4)*((V_0/V)**(2./3)-1) )', 
            function = birch, 
            params   = [('minimum_x', lambda p: p[1]),
                        ('e_inf',     lambda p: p[0]),
                        ('B',         lambda p: p[2]),
                        ('Bp',        lambda p: p[3]),
                        ('pressure' , lambda p, V: birch_pressure(p[1:], V))],
            ),         
        vinet = obj(
            nparam   = 4,
            format   = 'E_inf + 2*V_0*B/(Bp-1)**2*( 2 - (2+3*(Bp-1)*((V/V_0)**(1./3)-1))*exp(-1.5*(Bp-1)*((V/V_0)**(1./3)-1)) ) ', 
            function = vinet, 
            params   = [('minimum_x', lambda p: p[1]),
                        ('e_inf',     lambda p: p[0]),
                        ('B',         lambda p: p[2]),
                        ('Bp',        lambda p: p[3]),
                        ('pressure' , lambda p, V: vinet_pressure(p[1:], V))],
            ),
        murnaghan = obj(
            nparam   = 4,
            format   = 'E_inf + B/Bp*V*((V_0/V)**Bp/(Bp-1)+1)-V_0*B/(Bp-1)', 
            function = murnaghan, 
            params   = [('minimum_x', lambda p: p[1]),
                        ('e_inf',     lambda p: p[0]),
                        ('B',         lambda p: p[2]),
                        ('Bp',        lambda p: p[3]),
                        ('pressure' , lambda p, V: murnaghan_pressure(p[1:], V))],
            ),             
        ),
    )
# Fit functions for the fit type selected on the command line; filled in by
# the __main__ block from the matching group of all_fit_functions.
fit_functions = obj()



def qmcfit(q,E,fname='linear',minimizer=least_squares):
    """Fit energies E against the parameter values q, with jackknife statistics.

    Args:
        q: abscissa values, one per data set.
        E: energies. A list/tuple is converted to a float array. A 2D array is
            transposed, kept as the jackknife sample table, and its means over
            the first axis (after transposition) are fitted.
        fname: key of the fit function in the module-level ``fit_functions``.
        minimizer: minimizer handed through to ``curve_fit``.

    Returns:
        (pf, pmean, perror, auxres): the fit to the mean data, the jackknife
        mean and error of the fitted parameters, and an ``obj`` holding the
        ``jackknife_aux`` result for each derived quantity of the fit function.
    """
    # coerce plain Python sequences into a float array
    if isinstance(E,(list,tuple)):
        E = np.array(E,dtype=float)
    #end if

    # a genuinely two-dimensional table carries the per-sample data
    samples = None
    if len(E.shape)==2 and len(E)!=E.size:
        samples = E.T
        E       = samples.mean(axis=0)
    #end if

    # look up the model and its derived-quantity functions
    info    = fit_functions[fname]
    model   = info.function
    derived_funcs   = obj()
    derived_results = obj()
    for label,derived_func in info.params:
        derived_funcs[label] = derived_func
    #end for

    # initial guess: polynomial fit of matching degree (linear for any other
    # function), lowest order first, zero-padded to the model's parameter count
    degree  = {'quartic':4,'cubic':3,'quadratic':2}.get(fname,1)
    coeffs  = np.polyfit(q,E,degree)
    padding = (info.nparam-len(coeffs))*[0]
    p0 = tuple(list(coeffs[::-1])+padding)

    # optimized fit of the mean data
    pf = curve_fit(q,E,model,p0,minimizer)

    # jackknife mean and error of the fitted parameters, seeded with pf
    captured = obj()
    pmean,perror = jackknife(
        data     = samples,
        function = curve_fit,
        args     = [q,None,model,pf,minimizer],
        position = 1,
        capture  = captured,
        )

    # jackknife estimates of the derived quantities
    if len(derived_funcs)>0:
        psamples = captured.jsamples
        for label,derived_func in derived_funcs.items():
            derived_results[label] = jackknife_aux(psamples,derived_func)
        #end for
    #end if

    return pf,pmean,perror,derived_results
#end def qmcfit



# Reads scalar.dat files and extracts energy series
def process_scalar_files(scalar_files,equils=None,reblock_factors=None,series_start=None):
    """Read LocalEnergy series from scalar.dat files and reblock them to a common length.

    Args:
        scalar_files: paths of ``*.scalar.dat`` files.
        equils: equilibration length, either one integer for all files or one
            per file. If None, ``equilibration_length`` estimates it per file.
        reblock_factors: reblocking factor, either one integer for all files
            or one per file. If None, the autocorrelation estimate returned by
            ``simstats`` is used per file.
        series_start: if given, files listed before the first one whose series
            number equals this value are dropped. The series number is read
            from the third-from-last dot-separated field of the file name,
            minus its first character.

    Returns:
        (Edata, Emean, Eerr, scalar_files): the reblocked energies as a 2D
        float array with one row per file, the per-file mean and error from
        ``simstats``, and the list of files actually used.

    Invalid input is reported through ``error``.
    """
    if len(scalar_files)==0:
        error('must provide at least one scalar file')
    #end if
    for scalar_file in scalar_files:
        if not os.path.exists(scalar_file):
            error('scalar file does not exist: {0}'.format(scalar_file))
        #end if
        if not scalar_file.endswith('.scalar.dat'):
            error('file must be of type scalar.dat: {0}'.format(scalar_file))
        #end if
    #end for
    if series_start is not None:
        for n,scalar_file in enumerate(scalar_files):
            filename = os.path.split(scalar_file)[1]
            series = int(filename.split('.')[-3][1:])
            if series==series_start:
                scalar_files = scalar_files[n:]
                break
            #end if
        #end for
    #end if

    # broadcast a single equilibration length / reblock factor to all files
    if isinstance(equils,(int,np.int_)):
        equils = len(scalar_files)*[equils]
    elif equils is not None and len(equils)!=len(scalar_files):
        error('must provide one equilibration length per scalar file\nnumber of equils provided: {0}\nnumber of scalar files provided: {1}\nequils provided: {2}\nscalar files provided: {3}'.format(len(equils),len(scalar_files),equils,scalar_files))
    #end if

    if isinstance(reblock_factors,(int,np.int_)):
        reblock_factors = len(scalar_files)*[reblock_factors]
    elif reblock_factors is not None and len(reblock_factors)!=len(scalar_files):
        error('must provide one reblocking factor per scalar file\nnumber of reblock_factors provided: {0}\nnumber of scalar files provided: {1}\nreblock_factors provided: {2}\nscalar files provided: {3}'.format(len(reblock_factors),len(scalar_files),reblock_factors,scalar_files))
    #end if

    # extract energy data from scalar files
    Edata = []
    Emean = []
    Eerr  = []
    Ekap  = []
    for n,scalar_file in enumerate(scalar_files):
        # header line: the first two fields are skipped, the rest name the
        # data columns that follow the first column
        with open(scalar_file,'r') as fobj:
            quantities = fobj.readline().split()[2:]
        #end with
        rawdata = np.loadtxt(scalar_file)[:,1:].transpose()
        qdata = obj()
        for i,q in enumerate(quantities):
            qdata[q] = rawdata[i,:]
        #end for
        E = qdata.LocalEnergy

        # exclude blocks marked as equilibration
        if equils is not None:
            nbe = equils[n]
        else:
            nbe = equilibration_length(E)
        #end if
        if nbe>len(E):
            error('equilibration cannot be applied\nequilibration length given is greater than the number of blocks in the file\nfile name: {0}\n# blocks present: {1}\nequilibration length given: {2}'.format(scalar_file,len(E),nbe))
        #end if
        E = E[nbe:]

        mean,var,err,kap = simstats(E)
        Emean.append(mean)
        Eerr.append(err)
        Ekap.append(kap)

        Edata.append(E)
    #end for
    Emean = np.array(Emean)
    Eerr  = np.array(Eerr)
    Ekap  = np.array(Ekap)

    # reblock data into target length
    if reblock_factors is None:
        # find block targets based on autocorrelation time, if needed
        block_targets = [len(E)//kap for E,kap in zip(Edata,Ekap)]
    else:
        block_targets = [len(E)//rf for E,rf in zip(Edata,reblock_factors)]
    #end if

    # all series are reduced to the smallest target so they share one length
    bt = np.array(block_targets,dtype=int).min()
    if bt<1:
        # a zero target would otherwise surface as an obscure divide-by-zero
        error('reblocking cannot be applied\nat least one file has too few blocks left after equilibration for the requested reblocking\nblock targets: {0}\nscalar files provided: {1}'.format(block_targets,scalar_files))
    #end if
    for n,E in enumerate(Edata):
        # drop leading blocks so the length is an exact multiple of bt
        nbe = len(E)%bt
        E = E[nbe:]
        reblock = len(E)//bt
        # reshape() instead of assigning to .shape: assignment fails when the
        # array cannot be viewed in the new shape without a copy
        E = E.reshape(bt,reblock)
        if reblock>1:
            E = E.sum(1)/reblock
        #end if
        Edata[n] = E.reshape(bt)
    #end for
    Edata = np.array(Edata,dtype=float)

    return Edata,Emean,Eerr,scalar_files
#end def process_scalar_files



def stat_strings(mean,error):
    """Format a mean and its error with a shared number of decimal places.

    The number of decimals is chosen from the magnitude of ``error``
    (one more than the position of its leading digit), with a minimum of two.

    Args:
        mean: the value to format.
        error: the uncertainty of ``mean``.

    Returns:
        (mstr, estr): fixed-point strings for the mean and the error.
    """
    if np.isfinite(error) and error>0:
        d = int(max(2,1-np.floor(np.log(error)/np.log(10.))))
    else:
        # log() is undefined for a zero, negative or non-finite error; fall
        # back to the minimum precision instead of raising on int(inf/nan)
        d = 2
    #end if
    fmt = '{0:16.'+str(d)+'f}'
    mstr = fmt.format(mean).strip()
    estr = fmt.format(error).strip()
    return mstr,estr
#end def stat_strings


def parse_list(opt,name,dtype,len1=False):
    """Convert a whitespace-separated option string into a numpy array, in place.

    Args:
        opt: mapping-like option container supporting ``opt[name]`` access.
        name: key of the option to convert.
        dtype: numpy dtype for the resulting array.
        len1: if True and exactly one value was given, store that single
            value instead of a one-element array.

    A value of None is left untouched. Any failure to parse is reported
    through ``error``.
    """
    try:
        raw = opt[name]
        # identity test: '!= None' would compare elementwise on an array
        if raw is not None:
            values = np.array(raw.split(),dtype=dtype)
            if len1 and len(values)==1:
                values = values[0]
            #end if
            opt[name] = values
        #end if
    except Exception:
        error('{0} list misformatted: {1}'.format(name,opt[name]))
    #end try
#end def parse_list


def timestep_fit(args):
    """Fit energies against timestep, print the fit and optionally plot it.

    Args:
        args: parsed command-line namespace; reads ``scalar_files``,
            ``fit_function``, ``timesteps``, ``equils``, ``reblock_factors``,
            ``series_start`` and ``noplot``.

    Results are written through ``log``; invalid input is reported through
    ``error``.
    """
    opt = obj(**args.__dict__)
    scalar_files = sorted(opt.scalar_files)

    if len(scalar_files)==0:
        # NOTE: 'parser' is the module-level name bound in the __main__ block
        log('\n'+parser.format_help().strip()+'\n')
        sys.exit()
    #end if

    if opt.fit_function not in fit_functions:
        error('invalid fitting function: {0}\nvalid options are: {1}'.format(opt.fit_function,sorted(fit_functions.keys())))
    #end if

    # a missing -t option becomes an empty list, caught by the length check below
    if opt.timesteps is None:
        opt.timesteps = ''
    #end if
    parse_list(opt,'timesteps',float)
    parse_list(opt,'equils',int,len1=True)
    parse_list(opt,'reblock_factors',int,len1=True)
    
    # read in scalar energy data
    Edata,Emean,Eerror,scalar_files = process_scalar_files(
        scalar_files    = scalar_files,
        series_start    = opt.series_start,
        equils          = opt.equils,
        reblock_factors = opt.reblock_factors,
        )

    if len(Edata)!=len(opt.timesteps):
        error('must provide one timestep per scalar file\nnumber of timesteps provided: {0}\nnumber of scalar files provided: {1}\ntimesteps provided: {2}\nscalar files provided: {3}'.format(len(opt.timesteps),len(scalar_files),opt.timesteps,scalar_files))
    #end if

    # perform jackknife analysis of the fit
    pf,pmean,perror,auxres = qmcfit(opt.timesteps,Edata,opt.fit_function)

    # print text info about the fit results
    func_info = fit_functions[opt.fit_function]
    pvals = []
    for pm,pe in zip(pmean,perror):
        pvals.append('({0} +/- {1})'.format(*stat_strings(pm,pe)))
    #end for

    log('\nfit function  : '+opt.fit_function)
    log('fitted formula: '+func_info.format.format(*pvals))
    for pname,pfunc in func_info.params:
        pm,pe = stat_strings(*auxres[pname])
        log('{0:<14}: {1} +/- {2}  Ha\n'.format(pname,pm,pe))
    #end for

    # plot the fit (if available)
    if plots_available and not opt.noplot:
        lw = 2
        ms = 10

        ts = opt.timesteps
        tsmax = ts.max()
        # zero-timestep value and its error, drawn in red
        E0,E0err = auxres.intercept
        tsfit = np.linspace(0,1.1*tsmax,400)
        Efit  = func_info.function(pmean,tsfit)
        plt.figure()
        plt.plot(tsfit,Efit,'k-',lw=lw)
        plt.errorbar(ts,Emean,Eerror,fmt='b.',ms=ms)
        plt.errorbar([0],[E0],[E0err],fmt='r.',ms=ms)
        plt.xlim([-0.1*tsmax,1.1*tsmax])
        plt.xlabel('DMC Timestep (1/Ha)')
        plt.ylabel('DMC Energy (Ha)')
        plt.show()
    #end if
#end def timestep_fit

def hubbard_u_fit(args):
    """Fit energies against Hubbard U or EXX ratio, print the minima and optionally plot.

    Args:
        args: parsed command-line namespace; reads ``scalar_files``,
            ``fit_function``, ``hubbards``, ``exx``, ``equils``,
            ``reblock_factors``, ``series_start`` and ``noplot``.
            ``hubbards`` takes precedence over ``exx`` when both are set.

    Results are written through ``log``; invalid input is reported through
    ``error``.
    """
    opt = obj(**args.__dict__)
    scalar_files = sorted(opt.scalar_files)
    if opt.fit_function not in fit_functions:
        error('invalid fitting function: {0}\nvalid options are: {1}'.format(opt.fit_function,sorted(fit_functions.keys())))
    #end if

    # convert the option strings before they are handed to
    # process_scalar_files, which expects integers / integer lists
    parse_list(opt,'equils',int,len1=True)
    parse_list(opt,'reblock_factors',int,len1=True)

    # read in scalar energy data
    Edata,Emean,Eerror,scalar_files = process_scalar_files(
        scalar_files    = scalar_files,
        series_start    = opt.series_start,
        equils          = opt.equils,
        reblock_factors = opt.reblock_factors,
        )

    # select the abscissa: Hubbard U values or EXX ratios
    hubbard_u = True
    if opt.hubbards is not None:
        parse_list(opt,'hubbards',float)
        if len(Edata)!=len(opt.hubbards):
            error('must provide one hubbard_u value per scalar file\nnumber of hubbard_u provided: {0}\nnumber of scalar files provided: {1}\nhubbards provided: {2}\nscalar files provided: {3}'.format(len(opt.hubbards),len(scalar_files),opt.hubbards,scalar_files))
        #end if
        x = opt.hubbards
    elif opt.exx is not None:
        hubbard_u = False
        parse_list(opt,'exx',float)
        if len(Edata)!=len(opt.exx):
            error('must provide one exx value per scalar file\nnumber of exx provided: {0}\nnumber of scalar files provided: {1}\nexx provided: {2}\nscalar files provided: {3}'.format(len(opt.exx),len(scalar_files),opt.exx,scalar_files))
        #end if
        x = opt.exx
    else:
        # without abscissa values there is nothing to fit
        error('Please provide either EXX or Hubbard_U values')
        return
    #end if

    # perform jackknife analysis of the fit
    pf,pmean,perror,auxres = qmcfit(x,Edata,opt.fit_function)

    func_info = fit_functions[opt.fit_function]
    pvals = []
    for pm,pe in zip(pmean,perror):
        pvals.append('({0} +/- {1})'.format(*stat_strings(pm,pe)))
    #end for

    log('\nfit function  : '+opt.fit_function)
    log('fitted formula: '+func_info.format.format(*pvals))
    # each derived quantity holds one (mean,error) column per root
    for pname,pfunc in func_info.params:
        for i in range(len(auxres[pname][0])):
            pm,pe = stat_strings(*np.array(auxres[pname])[:,i])
            unit = ""
            if pname == "minimum_e":
                param = pname
                unit = "Ha"
            elif pname == "minimum_x":
                if hubbard_u:
                    param = "U at minimum"
                    unit = "eV"
                else:
                    param = "EXX at minimum"
                    unit = ''
                #end if
            else:
                param = pname
            #end if
            log('root {0} {1:<14}: {2} +/- {3} {4}'.format(i+1,param,pm,pe, unit))
        #end for
    #end for

    # plot the fit (if available)
    if plots_available and not opt.noplot:
        lw = 2
        ms = 10
        ts = x
        tsmin = ts.min()
        tsmax = ts.max()
        tsrange = tsmax-tsmin

        U_min,U_err = auxres.minimum_x
        E_min,E_err = auxres.minimum_e
        # only mark minima that are real and lie strictly inside the data range
        real_U_in_range = np.isreal(U_min) & (U_min < tsmax) &  (U_min > tsmin)
        U_min = np.real(U_min[real_U_in_range])
        U_err = np.real(U_err[real_U_in_range])
        E_min = np.real(E_min[real_U_in_range])
        tsfit = np.linspace(0,1.1*tsmax,400)
        Efit  = func_info.function(pmean,tsfit)
        plt.figure()
        plt.plot(tsfit,Efit,'k-',lw=lw)
        plt.errorbar(ts,Emean,Eerror,fmt='b.',ms=ms)
        plt.errorbar(U_min,E_min,xerr=U_err,fmt='r.',ms=ms)
        plt.xlim([tsmin-0.1*tsrange,tsmax + 0.1*tsrange])
        if hubbard_u:
            plt.xlabel('Hubbard U (eV)')
        else:
            plt.xlabel('EXX ratio')
        #end if
        plt.ylabel('DMC Energy (Ha)')
        plt.show()
    #end if    
#end def hubbard_u_fit 

def eos_fit(args):
    """Fit energies to a Morse or equation-of-state form, print and optionally plot.

    Args:
        args: parsed command-line namespace; reads ``scalar_files``,
            ``fit_function``, ``eos``, ``equils``, ``reblock_factors``,
            ``series_start`` and ``noplot``.

    The 'morse' fit function is handled by ``morse_fit``; every other fit
    function is handled by the numerics fit routine imported as
    ``num_eos_fit``. Results are written through ``log``; invalid input is
    reported through ``error``.
    """
    opt = obj(**args.__dict__)
    scalar_files = sorted(opt.scalar_files)
    if opt.fit_function not in fit_functions:
        error('invalid fitting function: {0}\nvalid options are: {1}'.format(opt.fit_function,sorted(fit_functions.keys())))
    #end if

    # convert the option strings before they are handed to
    # process_scalar_files, which expects integers / integer lists
    parse_list(opt,'eos',float)
    parse_list(opt,'equils',int,len1=True)
    parse_list(opt,'reblock_factors',int,len1=True)
    if opt.eos is None:
        error('must provide structural parameters for the EOS fit via --eos')
        return
    #end if

    Edata,Emean,Eerror,scalar_files = process_scalar_files(
        scalar_files    = scalar_files,
        series_start    = opt.series_start,
        equils          = opt.equils,
        reblock_factors = opt.reblock_factors,
        )

    if len(Edata)!=len(opt.eos):
        error('must provide one eos value per scalar file\nnumber of eos values provided: {0}\nnumber of scalar files provided: {1}\neos values provided: {2}\nscalar files provided: {3}'.format(len(opt.eos),len(scalar_files),opt.eos,scalar_files))
    #end if

    # perform jackknife analysis of the fit
    x = opt.eos
    
    jcapture = obj()
    auxfuncs = obj()
    auxres   = obj()
    func_info = fit_functions[opt.fit_function]
    if 'params' in func_info.keys():
        for name,func in func_info.params:
            auxfuncs[name]=func
        #end for
    #end if
    if opt.fit_function == 'morse':
        _, pmean, perror = morse_fit(x, np.transpose(Edata), jackknife=True, auxfuncs=auxfuncs, auxres=auxres, capture=jcapture)
    else:
        # num_eos_fit is the numerics fitting routine; the bare name eos_fit
        # refers to this driver function and must not be called here
        _, pmean, perror = num_eos_fit(x, np.transpose(Edata), type = opt.fit_function, jackknife=True, auxfuncs=auxfuncs, auxres=auxres, capture=jcapture)
    #end if

    pvals = []
    for pm,pe in zip(pmean,perror):
        pvals.append('({0} +/- {1})'.format(*stat_strings(pm,pe)))
    #end for
    log('\nfit function  : '+opt.fit_function)
    log('fitted formula: '+func_info.format.format(*pvals))

    if 'params' in func_info.keys():
        for pname,pfunc in func_info.params:
            pm,pe = stat_strings(*np.array(auxres[pname]))
            log('{0}: {1} +/- {2} '.format(pname,pm,pe))
        #end for
    #end if
    
    # plot the fit (if available)
    if plots_available and not opt.noplot:
        lw = 2
        ms = 10
        ts = x
        tsmin = ts.min()
        tsmax = ts.max()
        tsrange = tsmax-tsmin
        x_min,x_err = auxres.minimum_x
        if 'minimum_e' in auxres.keys():
            e_min,e_err = auxres.minimum_e
        else:
            # no derived minimum energy: evaluate the fit at the minimum position
            e_min = func_info.function(pmean, x_min)
        #end if
        tsfit = np.linspace(tsmin,1.1*tsmax,400)
        Efit  = func_info.function(pmean,tsfit)
        plt.figure()
        plt.plot(tsfit,Efit,'k-',lw=lw)
        plt.errorbar(ts,Emean,Eerror,fmt='b.',ms=ms)
        plt.errorbar(x_min,e_min,xerr=x_err,fmt='r.',ms=ms)
        plt.xlim([tsmin-0.1*tsrange,tsmax + 0.1*tsrange])
        if opt.fit_function == 'morse':
            plt.xlabel('Distance (A)')
        else:
            plt.xlabel('Volume (A^3)')
        #end if
        plt.ylabel('DMC Energy (Ha)')
        plt.show()
    #end if              
#end def eos_fit 

def parse_args():
    """This utility provides a fit to the one-dimensional parameter scans of QMC 
    observables. Currently, the functionality in place is to fit linear/quadratic polynomial
    fits to the timestep VMC/DMC studies and single parameter optimization of trial wavefunctions
    using DMC local energies and quadratic, cubic and quartic fits. 
    """
    # NOTE: the docstring above is runtime text - it is shown to the user as
    # the description in --help output. Returns (parser, args): the argument
    # parser and the namespace parsed from sys.argv.
    
    parser       = argparse.ArgumentParser(description=parse_args.__doc__,
                                           formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # required positional: a default would never be used, so none is given
    parser.add_argument('fit_type', choices=['ts', 'u', 'eos'],
                        help='One dimensional parameter used to fit QMC local energies. Options are ts for timestep, u for hubbard_u (or EXX ratio) parameter fitting and eos for Morse/equation-of-state fitting'
                        ) 
    # fit function names gathered across all fit types
    valid_fit_functions = {fname for group in all_fit_functions.keys() for fname in all_fit_functions[group].keys()}
    parser.add_argument('-f','--fit',dest='fit_function',default='linear',
                        help='Fitting function, options are {0}.'.format(sorted(valid_fit_functions))
                        )                        
    # List of 1-D parameters, mutually exclusive
    parameters = parser.add_mutually_exclusive_group(required=True)
    parameters.add_argument('-t',dest='timesteps',default=None,
                        help='Timesteps corresponding to scalar files, excluding any prior to --series_start'
                        )
    parameters.add_argument('-u',dest='hubbards',default=None, 
                        help='Hubbard U values (eV)'
                        )
    parameters.add_argument('--exx',dest='exx',default=None, 
                        help='EXX ratios'
                        )                        
    parameters.add_argument('--eos',dest='eos',default=None, 
                        help='Structural parameter for EOS fitting'
                        )           
    parser.add_argument('-s', '--series_start',dest='series_start',type=int, default=None,
                        help='Series number for first DMC run.  Use to exclude prior VMC scalar files if they have been provided'
                        )
    parser.add_argument('-e','--equils',dest='equils',default=None,
                        help='Equilibration lengths corresponding to scalar files, excluding any prior to --series_start.  Can be a single value for all files.  If not provided, equilibration periods will be estimated.'
                        )
    parser.add_argument('-b','--reblock_factors',dest='reblock_factors',default=None,
                        help='Reblocking factors corresponding to scalar files, excluding any prior to --series_start.  Can be a single value for all files.  If not provided, reblocking factors will be estimated.'
                       )
    parser.add_argument('--noplot',dest='noplot',action='store_true',default=False,
                        help='Do not show plots.'
                       )
    parser.add_argument('scalar_files', nargs='+',
                        help='Scalar files used in the fit. An explicit list of scalar files with space or a wildcard (e.g. dmc*/dmc.s001.scalar.dat) is acceptable.'
                        ) 

    args = parser.parse_args()
    return parser, args
#end def parse_args

if __name__=='__main__':
    # Command-line entry point: parse the options, select the fit functions
    # belonging to the requested fit type, then run the matching driver.
    # 'parser' must stay a module-level name: timestep_fit refers to it.
    parser, args = parse_args()
    fit_types = sorted(all_fit_functions.keys())
    if not args.scalar_files:
        log('\n'+'Please provide scalar files'+'\n')
        parser.print_help()
        exit()
    #end if

    fit_type = args.fit_type
    if fit_type not in fit_types:
        error('unknown fit type: {0}\nvalid options are: {1}'.format(fit_type,fit_types))
    else:
        # expose only the functions of the chosen fit type to the drivers
        fit_functions.clear()
        fit_functions.transfer_from(all_fit_functions[fit_type])
    #end if

    # one driver per fit type
    fit_drivers = {
        'ts'  : timestep_fit,
        'u'   : hubbard_u_fit,
        'eos' : eos_fit,
        }
    run_fit = fit_drivers.get(fit_type)
    if run_fit is None:
        error('unsupported fit type: {0}'.format(fit_type))
    else:
        run_fit(args)
    #end if
#end if