diff --git a/kernels/grid.go b/kernels/grid.go index 416faf01..fceb08e5 100644 --- a/kernels/grid.go +++ b/kernels/grid.go @@ -57,6 +57,9 @@ type Wavefront struct { InitExecMask uint64 WorkItems []*WorkItem + //for sampling + Finishtime sim.VTimeInSec + Issuetime sim.VTimeInSec } // NewWavefront returns a new Wavefront. diff --git a/samples/runner/runner.go b/samples/runner/runner.go index 77839b01..9866e04a 100644 --- a/samples/runner/runner.go +++ b/samples/runner/runner.go @@ -15,6 +15,7 @@ import ( "github.com/sarchlab/akita/v3/tracing" "github.com/sarchlab/mgpusim/v3/benchmarks" "github.com/sarchlab/mgpusim/v3/driver" + "github.com/sarchlab/mgpusim/v3/samplinglib" "github.com/tebeka/atexit" ) @@ -64,7 +65,7 @@ func (r *Runner) Init() *Runner { r.parseGPUFlag() log.SetFlags(log.Llongfile | log.Ldate | log.Ltime) - + samplinglib.InitSampledEngine() if r.Timing { r.buildTimingPlatform() } else { diff --git a/samplinglib/stableengine.go b/samplinglib/stableengine.go new file mode 100644 index 00000000..3d554161 --- /dev/null +++ b/samplinglib/stableengine.go @@ -0,0 +1,82 @@ +// Package samplinglib provides tools for performing sampling simulation +package samplinglib + +import ( + "github.com/sarchlab/akita/v3/sim" +) + +// WFFeature is used for recording the runtime info +type WFFeature struct { + Issuetime sim.VTimeInSec + Finishtime sim.VTimeInSec +} + +// StableEngine is used to detect if the feature detecting is stable or not +type StableEngine struct { + issuetimeSum sim.VTimeInSec + finishtimeSum sim.VTimeInSec + intervaltimeSum sim.VTimeInSec + mixSum sim.VTimeInSec + issuetimeSquareSum sim.VTimeInSec + rate float64 + granulary int + Wffeatures []WFFeature + boundary float64 + enableSampled bool + predTime sim.VTimeInSec +} + +// Analysis the data +func (se *StableEngine) Analysis() { + rateBottom := sim.VTimeInSec(se.granulary)*se.issuetimeSquareSum - se.issuetimeSum*se.issuetimeSum + rateTop := sim.VTimeInSec(se.granulary)*se.mixSum - se.issuetimeSum*se.finishtimeSum + rate := float64(rateTop / rateBottom) + se.rate = rate + boundary := se.boundary + se.predTime = se.intervaltimeSum / sim.VTimeInSec(se.granulary) + if rate >= (1-boundary) && rate <= (1+boundary) { + se.enableSampled = true + } else { + se.enableSampled = false + } +} + +// Reset all information +func (se *StableEngine) Reset() { + se.Wffeatures = nil + se.issuetimeSum = 0 + se.finishtimeSum = 0 + se.intervaltimeSum = 0 + se.mixSum = 0 + se.issuetimeSquareSum = 0 + se.predTime = 0 + se.enableSampled = false +} + +// Collect data +func (se *StableEngine) Collect(issuetime, finishtime sim.VTimeInSec) { + wffeature := WFFeature{ + Issuetime: issuetime, + Finishtime: finishtime, + } + + se.Wffeatures = append(se.Wffeatures, wffeature) + se.issuetimeSum += issuetime + se.finishtimeSum += finishtime + se.mixSum += finishtime * issuetime + se.issuetimeSquareSum += issuetime * issuetime + se.intervaltimeSum += (finishtime - issuetime) + if len(se.Wffeatures) == se.granulary { + se.Analysis() + ///delete old data + wffeature2 := se.Wffeatures[0] + se.Wffeatures = se.Wffeatures[1:] + issuetime = wffeature2.Issuetime + finishtime = wffeature2.Finishtime + se.issuetimeSum -= issuetime + se.finishtimeSum -= finishtime + se.mixSum -= finishtime * issuetime + se.issuetimeSquareSum -= issuetime * issuetime + se.intervaltimeSum -= (finishtime - issuetime) + } +} diff --git a/samplinglib/wfsampling.go b/samplinglib/wfsampling.go new file mode 100644 index 00000000..2f5536f0 --- /dev/null +++ b/samplinglib/wfsampling.go @@ -0,0 +1,137 @@ +package samplinglib + +import ( + "flag" + "log" + "time" + + "github.com/sarchlab/akita/v3/sim" +) + +// SampledRunnerFlag is used to enable wf sampling +var SampledRunnerFlag = flag.Bool("wf-sampling", false, "enable wavefront-level sampled simulation.") + +// SampledRunnerThresholdFlag is used to set the threshold of the sampling +var SampledRunnerThresholdFlag = flag.Float64("sampled-threshold", 0.03, + "the threshold of the sampled execution to enable sampling simulation.") + +// SampledRunnerGranularyFlag is used to set the granulary of the sampling +var SampledRunnerGranularyFlag = flag.Int("sampled-granulary", 1024, + "the granulary of the sampled execution to collect and analyze data.") + +// SampledEngine is used to detect if the wavefront sampling is stable or not. +type SampledEngine struct { + predTime sim.VTimeInSec + enableSampled bool + disableEngine bool + Simtime float64 `json:"simtime"` + Walltime float64 `json:"walltime"` + FullSimWalltime float64 `json:"fullsimwalltime"` + FullSimWalltimeStart time.Time + dataidx uint64 + stableEngine *StableEngine + shortStableEngine *StableEngine + predTimeSum sim.VTimeInSec + predTimeNum uint64 + granulary int +} + +// Reset all status +func (se *SampledEngine) Reset() { + se.FullSimWalltimeStart = time.Now() + se.stableEngine.Reset() + se.shortStableEngine.Reset() + se.predTime = 0 + se.predTimeNum = 0 + se.predTimeSum = 0 + se.dataidx = 0 + se.enableSampled = false +} + +// NewSampledEngine is used to new a sampled engine for wavefront sampling +func NewSampledEngine(granulary int, boundary float64, control bool) *SampledEngine { + stableEngine := &StableEngine{ + granulary: granulary, + boundary: boundary, + } + shortStableEngine := &StableEngine{ + granulary: granulary / 2, + boundary: boundary, + } + ret := &SampledEngine{ + stableEngine: stableEngine, + shortStableEngine: shortStableEngine, + granulary: granulary / 2, + } + ret.Reset() + if control { + ret.disableEngine = false + } + return ret +} + +// Sampledengine is used to monitor wavefront sampling +var Sampledengine *SampledEngine + +// InitSampledEngine is used to initial all status and data structure +func InitSampledEngine() { + Sampledengine = NewSampledEngine(*SampledRunnerGranularyFlag, *SampledRunnerThresholdFlag, false) + if *SampledRunnerFlag { + Sampledengine.Enable() + } else { + Sampledengine.Disabled() + } +} + +// Disabled the sampling engine +func (se *SampledEngine) Disabled() { + se.disableEngine = true +} + +// Enable the sampling engine +func (se *SampledEngine) Enable() { + se.disableEngine = false +} + +// IfDisable the sampling engine +func (se *SampledEngine) IfDisable() bool { + return se.disableEngine +} + +// Collect the runtime information +func (se *SampledEngine) Collect(issuetime sim.VTimeInSec, finishtime sim.VTimeInSec) { + if se.enableSampled || se.disableEngine { //we do not need to collect data if sampling is enabled + return + } + se.dataidx++ + if se.dataidx < 1024 { // discard the first 1024 data + return + } + se.stableEngine.Collect(issuetime, finishtime) + se.shortStableEngine.Collect(issuetime, finishtime) + stableEngine := se.stableEngine + shortStableEngine := se.shortStableEngine + if stableEngine.enableSampled { + longTime := stableEngine.predTime + shortTime := shortStableEngine.predTime + se.predTime = shortStableEngine.predTime + diff := float64((longTime - shortTime) / (longTime + shortTime)) + diffBoundary := *SampledRunnerThresholdFlag + if diff <= diffBoundary && diff >= -diffBoundary { + se.enableSampled = true + se.predTime = shortTime + se.predTimeSum = shortTime * sim.VTimeInSec(se.granulary) + se.predTimeNum = uint64(se.granulary) + } + } else if shortStableEngine.enableSampled { + se.predTime = stableEngine.predTime + } + if se.enableSampled { + log.Printf("Warp Sampling is enabled") + } +} + +// Predict the execution time of the next wavefronts +func (se *SampledEngine) Predict() (sim.VTimeInSec, bool) { + return se.predTime, se.enableSampled +} diff --git a/timing/cp/commandprocessor.go b/timing/cp/commandprocessor.go index eee71d94..064c2e7d 100644 --- a/timing/cp/commandprocessor.go +++ b/timing/cp/commandprocessor.go @@ -8,6 +8,7 @@ import ( "github.com/sarchlab/akita/v3/sim" "github.com/sarchlab/akita/v3/tracing" "github.com/sarchlab/mgpusim/v3/protocol" + "github.com/sarchlab/mgpusim/v3/samplinglib" "github.com/sarchlab/mgpusim/v3/timing/cp/internal/dispatching" "github.com/sarchlab/mgpusim/v3/timing/cp/internal/resource" "github.com/sarchlab/mgpusim/v3/timing/pagemigrationcontroller" @@ -297,7 +298,9 @@ func (p *CommandProcessor) processLaunchKernelReq( if d == nil { return false } - + if *samplinglib.SampledRunnerFlag { + samplinglib.Sampledengine.Reset() + } d.StartDispatching(req) p.ToDriver.Retrieve(now) diff --git a/timing/cp/internal/dispatching/dispatcher.go b/timing/cp/internal/dispatching/dispatcher.go index bb666f7e..66ec63a6 100644 --- a/timing/cp/internal/dispatching/dispatcher.go +++ b/timing/cp/internal/dispatching/dispatcher.go @@ -9,6 +9,7 @@ import ( "github.com/sarchlab/akita/v3/tracing" "github.com/sarchlab/mgpusim/v3/kernels" "github.com/sarchlab/mgpusim/v3/protocol" + "github.com/sarchlab/mgpusim/v3/samplinglib" "github.com/sarchlab/mgpusim/v3/timing/cp/internal/resource" ) @@ -113,6 +114,15 @@ func (d *DispatcherImpl) Tick(now sim.VTimeInSec) (madeProgress bool) { return madeProgress } +func (d *DispatcherImpl) collectSamplingData(locations []protocol.WfDispatchLocation) { + if *samplinglib.SampledRunnerFlag { + for _, l := range locations { + wavefront := l.Wavefront + samplinglib.Sampledengine.Collect(wavefront.Issuetime, wavefront.Finishtime) + } + } +} + func (d *DispatcherImpl) processMessagesFromCU(now sim.VTimeInSec) bool { msg := d.dispatchingPort.Peek() if msg == nil { @@ -123,9 +133,11 @@ func (d *DispatcherImpl) processMessagesFromCU(now sim.VTimeInSec) bool { case *protocol.WGCompletionMsg: count := 0 for _, rspToID := range msg.RspTo { - _, ok := d.inflightWGs[rspToID] + location, ok := d.inflightWGs[rspToID] if ok { count += 1 + ///sampling + d.collectSamplingData(location.locations) } } diff --git a/timing/cu/computeunit.go b/timing/cu/computeunit.go index d6d8f2a1..25840c5f 100644 --- a/timing/cu/computeunit.go +++ b/timing/cu/computeunit.go @@ -12,6 +12,7 @@ import ( "github.com/sarchlab/mgpusim/v3/insts" "github.com/sarchlab/mgpusim/v3/kernels" "github.com/sarchlab/mgpusim/v3/protocol" + "github.com/sarchlab/mgpusim/v3/samplinglib" "github.com/sarchlab/mgpusim/v3/timing/wavefront" ) @@ -71,6 +72,8 @@ type ComputeUnit struct { currentFlushReq *protocol.CUPipelineFlushReq currentRestartReq *protocol.CUPipelineRestartReq + //for sampling + wftime map[string]sim.VTimeInSec } // ControlPort returns the port that can receive controlling messages from the @@ -309,6 +312,56 @@ func (cu *ComputeUnit) processInputFromACE(now sim.VTimeInSec) bool { } } +// Handle the wavefront completion events +func (cu *ComputeUnit) Handle(evt sim.Event) error { + ctx := sim.HookCtx{ + Domain: cu, + Pos: sim.HookPosBeforeEvent, + Item: evt, + } + cu.InvokeHook(ctx) + + cu.Lock() + + defer cu.Unlock() + + switch evt := evt.(type) { + case *wavefront.WfCompletionEvent: + cu.handleWfCompletionEvent(evt) + default: + log.Panicf("Unable to process evevt of type %s", + reflect.TypeOf(evt)) + } + + ctx.Pos = sim.HookPosAfterEvent + cu.InvokeHook(ctx) + + return nil +} +func (cu *ComputeUnit) handleWfCompletionEvent(evt *wavefront.WfCompletionEvent) error { + wf := evt.Wf + wf.State = wavefront.WfCompleted + sTmp := cu.Scheduler + s := sTmp.(*SchedulerImpl) + if s.areAllOtherWfsInWGCompleted(wf.WG, wf) { + now := evt.Time() + + done := s.sendWGCompletionMessage(now, wf.WG) + if !done { + newEvent := wavefront.NewWfCompletionEvent(cu.Freq.NextTick(now), cu, wf) + cu.Engine.Schedule(newEvent) + return nil + } + + s.resetRegisterValue(wf) + cu.clearWGResource(wf.WG) + tracing.EndTask(wf.UID, cu) + tracing.TraceReqComplete(wf.WG.MapReq, cu) + + return nil + } + return nil +} func (cu *ComputeUnit) handleMapWGReq( now sim.VTimeInSec, req *protocol.MapWGReq, @@ -317,20 +370,47 @@ func (cu *ComputeUnit) handleMapWGReq( tracing.TraceReqReceive(req, cu) - for i, wf := range wg.Wfs { - location := req.Wavefronts[i] - cu.WfPools[location.SIMDID].AddWf(wf) - cu.WfDispatcher.DispatchWf(now, wf, req.Wavefronts[i]) - wf.State = wavefront.WfReady - - tracing.StartTaskWithSpecificLocation(wf.UID, - tracing.MsgIDAtReceiver(req, cu), - cu, - "wavefront", - "wavefront", - cu.Name()+".WFPool", - nil, - ) + //sampling + skipSimulate := false + if *samplinglib.SampledRunnerFlag { + for _, wf := range wg.Wfs { + cu.wftime[wf.UID] = now + } + wfpredicttime, wfsampled := samplinglib.Sampledengine.Predict() + predtime := wfpredicttime + skipSimulate = wfsampled + for _, wf := range wg.Wfs { + if skipSimulate { + predictedTime := predtime + now + wf.State = wavefront.WfSampledCompleted + newEvent := wavefront.NewWfCompletionEvent(predictedTime, cu, wf) + cu.Engine.Schedule(newEvent) + tracing.StartTask(wf.UID, + tracing.MsgIDAtReceiver(req, cu), + cu, + "wavefront", + "wavefront", + nil, + ) + } + } + } + if !skipSimulate { + for i, wf := range wg.Wfs { + location := req.Wavefronts[i] + cu.WfPools[location.SIMDID].AddWf(wf) + cu.WfDispatcher.DispatchWf(now, wf, req.Wavefronts[i]) + wf.State = wavefront.WfReady + + tracing.StartTaskWithSpecificLocation(wf.UID, + tracing.MsgIDAtReceiver(req, cu), + cu, + "wavefront", + "wavefront", + cu.Name()+".WFPool", + nil, + ) + } } cu.running = true @@ -809,6 +889,6 @@ func NewComputeUnit( cu.ToScalarMem = sim.NewLimitNumMsgPort(cu, 4, name+".ToScalarMem") cu.ToVectorMem = sim.NewLimitNumMsgPort(cu, 4, name+".ToVectorMem") cu.ToCP = sim.NewLimitNumMsgPort(cu, 4, name+".ToCP") - + cu.wftime = make(map[string]sim.VTimeInSec) return cu } diff --git a/timing/cu/scheduler.go b/timing/cu/scheduler.go index e8edc542..157ea2a2 100644 --- a/timing/cu/scheduler.go +++ b/timing/cu/scheduler.go @@ -8,6 +8,7 @@ import ( "github.com/sarchlab/akita/v3/tracing" "github.com/sarchlab/mgpusim/v3/insts" "github.com/sarchlab/mgpusim/v3/protocol" + "github.com/sarchlab/mgpusim/v3/samplinglib" "github.com/sarchlab/mgpusim/v3/timing/wavefront" ) @@ -270,7 +271,16 @@ func (s *SchedulerImpl) evalSEndPgm( wf.OutstandingScalarMemAccess > 0 { return false, false } - + ////sampling + if *samplinglib.SampledRunnerFlag { + issuetime, found := s.cu.wftime[wf.UID] + if found { + finishtime := now + wf.Finishtime = finishtime + wf.Issuetime = issuetime + delete(s.cu.wftime, wf.UID) + } + } if s.areAllOtherWfsInWGCompleted(wf.WG, wf) { done := s.sendWGCompletionMessage(now, wf.WG) if !done { diff --git a/timing/wavefront/wavefront.go b/timing/wavefront/wavefront.go index b95ae2ad..618331e9 100644 --- a/timing/wavefront/wavefront.go +++ b/timing/wavefront/wavefront.go @@ -15,11 +15,12 @@ type WfState int // A list of all possible WfState const ( - WfDispatching WfState = iota // Dispatching in progress, not ready to run - WfReady // Allow the scheduler to schedule instruction - WfRunning // Instruction in fight - WfCompleted // Wavefront completed - WfAtBarrier // Wavefront at barrier + WfDispatching WfState = iota // Dispatching in progress, not ready to run + WfReady // Allow the scheduler to schedule instruction + WfRunning // Instruction in fight + WfCompleted // Wavefront completed + WfAtBarrier // Wavefront at barrier + WfSampledCompleted // Wavefront completed at Sampling ) // A Wavefront in the timing package contains the information of the progress diff --git a/timing/wavefront/wfcompletionevent.go b/timing/wavefront/wfcompletionevent.go new file mode 100644 index 00000000..dec10f35 --- /dev/null +++ b/timing/wavefront/wfcompletionevent.go @@ -0,0 +1,24 @@ +package wavefront + +import ( + "github.com/sarchlab/akita/v3/sim" + // "gitlab.com/akita/mgpusim/v3/timing/wavefront" +) + +// A WfCompletionEvent marks the completion of a wavefront +type WfCompletionEvent struct { + *sim.EventBase + Wf *Wavefront +} + +// NewWfCompletionEvent returns a newly constructed WfCompleteEvent +func NewWfCompletionEvent( + time sim.VTimeInSec, + handler sim.Handler, + wf *Wavefront, +) *WfCompletionEvent { + evt := new(WfCompletionEvent) + evt.EventBase = sim.NewEventBase(time, handler) + evt.Wf = wf + return evt +}