package plotter

import scala.collection.mutable.ArrayBuffer
import org.jfree.chart.JFreeChart
import org.jfree.chart.axis.LogAxis
import org.jfree.chart.axis.NumberAxis
import org.jfree.data.xy.DefaultXYDataset
import org.jfree.chart.plot.XYPlot
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer
import java.awt.Font

/**
 * A simple class for plotting data sets with JFreeChart.
 * @author Shaofeng Jiang
 */
class Plotter:
  // Backing field for the x-axis scale flag; starts out as false (NumberAxis is used).
  private var _xLogScale: Boolean = false

  /** Tells whether the x-axis is currently configured to use a logarithmic scale. */
  def xLogScale: Boolean = _xLogScale

  /** Enables or disables logarithmic scale on the x-axis (write `plotter.xLogScale = true` to enable it). */
  def xLogScale_=(v: Boolean): Unit =
    _xLogScale = v

  // Backing field for the y-axis scale flag; starts out as false (NumberAxis is used).
  private var _yLogScale: Boolean = false

  /** Tells whether the y-axis is currently configured to use a logarithmic scale. */
  def yLogScale: Boolean = _yLogScale

  /** Enables or disables logarithmic scale on the y-axis (write `plotter.yLogScale = true` to enable it). */
  def yLogScale_=(v: Boolean): Unit =
    _yLogScale = v

  /**
   * Creates one axis: a LogAxis when `logarithmic` is true, a NumberAxis otherwise,
   * with its tick labels set to a plain 16-point serif font.
   */
  private def buildAxis(label: String, logarithmic: Boolean) =
    val axis = if logarithmic then new LogAxis(label) else new NumberAxis(label)
    axis.setTickLabelFont(new Font(Font.SERIF, Font.PLAIN, 16))
    axis

  /**
   * Builds a JFreeChart XY plot containing one series per supplied data set.
   *
   * Each data set's points are split into their first and second components, which
   * are registered as a series under the data set's name. The axis types follow the
   * current `xLogScale` / `yLogScale` settings.
   *
   * @param xAxisName label of the x-axis
   * @param yAxisName label of the y-axis
   * @param dataSets  the data sets to draw; at least one is required
   * @return the plot wrapped in a JFreeChart
   * @throws IllegalArgumentException if no data set is given
   */
  def plot(xAxisName: String, yAxisName: String, dataSets: DataSet*): JFreeChart =
    require(dataSets.nonEmpty, "Must provide at least one data set")
    val series = new DefaultXYDataset
    dataSets.foreach { ds =>
      // Split the points into the two parallel arrays that addSeries expects.
      val (firsts, seconds) = ds.points.unzip
      series.addSeries(ds.name, Array(firsts, seconds))
    }
    val domainAxis = buildAxis(xAxisName, xLogScale)
    val rangeAxis = buildAxis(yAxisName, yLogScale)
    // Both renderer flags are set, i.e. lines and shapes are both requested.
    val renderer = new XYLineAndShapeRenderer(true, true)
    new JFreeChart(new XYPlot(series, domainAxis, rangeAxis, renderer))
