from __future__ import print_function
import numpy as np
import matplotlib.pyplot as plt
import os


def extract_lc(t, L, dex=2.0):
    """Extract the last complete flash from a luminosity history.

    The history is split at the midpoint ``Lav = (max(L) + min(L)) / 2``.
    Scanning backwards from the end, the code skips any trailing points
    above Lav (a flash still in progress when the history stops), then the
    quiescent points at or below Lav, then the flash itself.  The flash is
    the slice from the last quiescent point before it up to (but excluding)
    the first quiescent point after it.

    Parameters
    ----------
    t : array_like
        Times, one per entry of ``L``.
    L : array_like
        log10 of the luminosity.
    dex : float, optional
        Keep only points with ``L > max(L_flash) - dex``.  The default of 2
        keeps the points brighter than 1% of the flash peak.

    Returns
    -------
    t1 : ndarray
        Times of the retained points, shifted so the first one is 0.
    L1 : ndarray
        Linear luminosity ``10**L`` of the retained points.

    Both arrays are empty when no complete flash is found.
    """
    t = np.asarray(t)
    L = np.asarray(L)
    Lav = 0.5 * (L.max() + L.min())

    # Every scan stops at index 0 instead of silently wrapping round to the
    # end of the array through negative indices.
    i = len(L) - 1
    while i >= 0 and L[i] > Lav:    # trailing, unfinished flash
        i -= 1
    i1 = i
    while i >= 0 and L[i] <= Lav:   # quiescent phase after the last flash
        i -= 1
    while i >= 0 and L[i] > Lav:    # the flash itself
        i -= 1
    i2 = i

    if i2 < 0:
        # No complete flash: either nothing above Lav precedes the final
        # quiescent phase, or the flash already began at the first sample.
        return t[:0], 10.0**L[:0]

    t1 = t[i2:i1]
    L1 = L[i2:i1]
    # Keep only points brighter than 10**-dex of the peak (dex=2 -> 1%).
    # The mask is computed once so times and luminosities stay aligned.
    keep = L1 > L1.max() - dex
    t1 = t1[keep]
    L1 = L1[keep]
    # set the start time to zero
    t1 = t1 - t1[0]

    return t1, 10.0**L1
    
        
def read_profile(name, dir=''):
    """Read the age and log-luminosity columns of a history data file.

    The file ``<dir>/<name>.data`` is parsed with ``np.genfromtxt``.  The
    first 5 lines are skipped, and the next line supplies the column names
    (presumably the MESA history-file layout -- confirm).

    Parameters
    ----------
    name : str
        File name without the ``.data`` extension.
    dir : str, optional
        Directory containing the file.  The default '' means the current
        working directory.

    Returns
    -------
    star_age, log_L : ndarray
        The ``star_age`` and ``log_L`` columns of the file.
    """
    # os.path.join avoids building '/<name>.data' (a path at the filesystem
    # root), which plain concatenation produced when dir was left as ''.
    path = os.path.join(dir, name + '.data')
    data = np.genfromtxt(path, names=True, skip_header=5)
    return data['star_age'], data['log_L']


# Physical constants (cgs values).
mj = 1.8986e30     # Jupiter mass, g
me = 5.9764e27     # Earth mass, g
rj = 6.9911e9      # Jupiter radius, cm
Lsun = 3.8418e33   # solar luminosity, erg/s

# Solar mass (1.9892e33 g) expressed in Earth and Jupiter masses.
msol_me = 1.9892e33 / me
msol_mj = 1.9892e33 / mj

# Eddington luminosity per solar mass, converted to Lsun; the plot below
# multiplies it by the model mass in Msun.  (1.48e38 erg/s presumably
# corresponds to an electron-scattering opacity of ~0.34 cm^2/g -- confirm.)
LEdd = 1.48e38 / Lsun

# Single-panel figure; the plt.* calls that follow draw onto this axes.
fig = plt.figure(figsize=(8,5))
ax = fig.add_subplot(1, 1, 1)

# Axis labels and tick-label sizes (applied directly to the axes object).
ax.set_xlabel(r'${\rm Time}\ [{\rm yr}]$', fontsize=11)
ax.set_ylabel(r'$L\ [L_\odot]$', fontsize=12)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.tick_params(axis='both', which='minor', labelsize=12)
# For a logarithmic luminosity axis: ax.set_yscale('log')


# Light curve of a single model (mass in Msun, metallicity Z) at
# Mdot = 1e-9 Msun/yr, saved to lc.pdf.
mass = 0.8
Z = 0.02
hist_name = 'history_M%.2f_mdot1e-09_Z%g' % (mass, Z)
t, L = read_profile(hist_name, dir='../grid/history')
t1, L1 = extract_lc(t, L)
if len(t1) == 0:
    # extract_lc can return empty arrays; fail with a clear message instead
    # of an IndexError on t1[0] below.
    raise RuntimeError('no complete flash found in %s' % hist_name)
plt.plot(t1, L1, label=r'$M=%.2f M_\odot$' % mass)
# Dotted reference line at mass*LEdd (Eddington luminosity, in Lsun),
# spanning the time range of the light curve.
plt.plot([t1[0], t1[-1]], [mass * LEdd, mass * LEdd], ':')
plt.title(r'$\dot M = 10^{-9}\ M_\odot\ {\rm yr^{-1}}$')
plt.legend(ncol=2)
plt.xlim((-0.1, 10))
plt.savefig('lc.pdf', bbox_inches='tight')



if False:
    # Disabled: light curves for a series of masses at Mdot = 1e-9 Msun/yr,
    # saved to lc_M.pdf.
    # NOTE(review): this draws onto the current figure, which already holds
    # the single-model curve and reference line plotted above -- confirm intent.
    # NOTE(review): these file names have no _Z suffix, unlike the file read
    # by the single-model plot above -- confirm these files still exist.
    masses = [0.51, 0.60, 0.65, 0.70, 0.80, 0.90, 1.00, 1.05, 1.10, 1.15]
    hist_dir = '../grid/history'

    for mass in masses:
        stem = 'history_M%.2f_mdot1e-09' % mass
        if not os.path.isfile('%s/%s.data' % (hist_dir, stem)):
            continue  # model not computed; skip it
        t, L = read_profile(stem, dir=hist_dir)
        t1, L1 = extract_lc(t, L)
        plt.plot(t1, L1, label=r'$M=%.2f M_\odot$' % mass)

    plt.title(r'$\dot M = 10^{-9}\ M_\odot\ {\rm yr^{-1}}$')
    plt.legend(ncol=2)
    plt.xlim((-0.1, 10))
    plt.savefig('lc_M.pdf', bbox_inches='tight')

if False:
    # Disabled: light curves at a fixed mass for a range of metallicities,
    # saved to lc_Z_M<mass>.pdf.
    # NOTE(review): this draws onto the current figure, which already holds
    # the curves plotted above -- confirm intent.
    mass = 1.0
    Zs = 0.02 + 0.04 * np.arange(12)   # Z = 0.02, 0.06, ..., 0.46
    hist_dir = '../grid/history'

    for Z in Zs:
        stem = 'history_M%.2f_mdot1e-09_Z%g' % (mass, Z)
        if not os.path.isfile('%s/%s.data' % (hist_dir, stem)):
            continue  # model not computed; skip it
        t, L = read_profile(stem, dir=hist_dir)
        t1, L1 = extract_lc(t, L)
        plt.plot(t1, L1, label=r'$Z=%.2f$' % Z)
        # Alternative, normalized to 0.8*LEdd (note: 0.8, not this block's mass):
        # plt.plot(t1, L1/(LEdd*0.8), label=r'$Z=%.2f$' % Z)

    plt.title(r'$M=%.2f\ M_\odot, \dot M = 10^{-9}\ M_\odot\ {\rm yr^{-1}}$' % (mass,))
    plt.legend(ncol=2)
    plt.xlim((-0.1, 2.0))
    # For a logarithmic time axis: ax.set_xscale('log')
    plt.savefig('lc_Z_M%.2f.pdf' % (mass,), bbox_inches='tight')
