//go:build linux package osimage import ( "bufio" "context" "fmt" "os" "path/filepath" "strconv" "strings" "sync" "time" "golang.org/x/sys/unix" ) type adaptiveWriteController struct { mu sync.Mutex limiter *rateLimiter monitor *diskBusyMonitor sampleInterval time.Duration nextSampleAt time.Time minBPS int64 maxBPS int64 busyHighPct float64 busyLowPct float64 } func newAdaptiveWriteController(targetPath string) (*adaptiveWriteController, error) { mon, err := newDiskBusyMonitor(targetPath) if err != nil { return nil, err } now := time.Now() return &adaptiveWriteController{ limiter: newRateLimiter(defaultInitialWriteBPS, defaultBurstBytes), monitor: mon, sampleInterval: defaultSampleInterval, nextSampleAt: now.Add(defaultSampleInterval), minBPS: defaultMinWriteBPS, maxBPS: defaultMaxWriteBPS, busyHighPct: defaultBusyHighPct, busyLowPct: defaultBusyLowPct, }, nil } func newNoopAdaptiveWriteController() *adaptiveWriteController { return &adaptiveWriteController{ limiter: newRateLimiter(0, 0), sampleInterval: defaultSampleInterval, } } func (c *adaptiveWriteController) Wait(ctx context.Context, n int) error { if c == nil || c.limiter == nil { return nil } return c.limiter.Wait(ctx, n) } func (c *adaptiveWriteController) ObserveWrite(n int) { c.observe(false) } func (c *adaptiveWriteController) ObserveSync() { c.observe(true) } func (c *adaptiveWriteController) observe(afterSync bool) { if c == nil { return } c.mu.Lock() defer c.mu.Unlock() now := time.Now() if c.monitor == nil || now.Before(c.nextSampleAt) { return } c.nextSampleAt = now.Add(c.sampleInterval) s, err := c.monitor.Sample(now) if err != nil { return } cur := c.limiter.Rate() if cur <= 0 { cur = c.minBPS } switch { case s.UtilPct >= c.busyHighPct || s.Await >= defaultSlowAwait || afterSync: // Back off aggressively when the disk is obviously suffering. next := cur / 2 if next < c.minBPS { next = c.minBPS } c.limiter.SetRate(next) case s.UtilPct <= c.busyLowPct && s.Await <= defaultFastAwait: // Recover slowly. next := cur + (cur / 5) // +20% if next > c.maxBPS { next = c.maxBPS } c.limiter.SetRate(next) } } type rateLimiter struct { mu sync.Mutex rateBPS int64 burst int64 tokens float64 last time.Time } func newRateLimiter(rateBPS, burst int64) *rateLimiter { now := time.Now() if burst < 0 { burst = 0 } return &rateLimiter{ rateBPS: rateBPS, burst: burst, tokens: float64(burst), last: now, } } func (r *rateLimiter) Rate() int64 { r.mu.Lock() defer r.mu.Unlock() return r.rateBPS } func (r *rateLimiter) SetRate(rateBPS int64) { r.mu.Lock() defer r.mu.Unlock() r.refillLocked(time.Now()) r.rateBPS = rateBPS if rateBPS <= 0 { r.tokens = 0 r.burst = 0 return } // Keep burst small and fixed. Do not let burst scale with rate. r.burst = defaultBurstBytes if r.tokens > float64(r.burst) { r.tokens = float64(r.burst) } } func (r *rateLimiter) Wait(ctx context.Context, n int) error { if n <= 0 { return nil } remaining := n for remaining > 0 { r.mu.Lock() if r.rateBPS <= 0 { r.mu.Unlock() return nil } now := time.Now() r.refillLocked(now) allowed := remaining if int64(allowed) > r.burst && r.burst > 0 { allowed = int(r.burst) } if allowed <= 0 { allowed = remaining } if r.tokens >= float64(allowed) { r.tokens -= float64(allowed) r.mu.Unlock() remaining -= allowed continue } missing := float64(allowed) - r.tokens waitDur := time.Duration(missing / float64(r.rateBPS) * float64(time.Second)) if waitDur < 5*time.Millisecond { waitDur = 5 * time.Millisecond } r.mu.Unlock() timer := time.NewTimer(waitDur) select { case <-ctx.Done(): timer.Stop() return ctx.Err() case <-timer.C: } } return nil } func (r *rateLimiter) refillLocked(now time.Time) { if r.rateBPS <= 0 { r.last = now return } elapsed := now.Sub(r.last) if elapsed <= 0 { return } r.tokens += elapsed.Seconds() * float64(r.rateBPS) if r.tokens > float64(r.burst) { r.tokens = float64(r.burst) } r.last = now } type diskBusySample struct { UtilPct float64 Await time.Duration } type diskBusyMonitor struct { major int minor int lastAt time.Time lastIOMs uint64 lastWrites uint64 } func newDiskBusyMonitor(targetPath string) (*diskBusyMonitor, error) { major, minor, err := resolveWholeDiskMajorMinor(targetPath) if err != nil { return nil, err } ioMs, writes, err := readDiskStats(major, minor) if err != nil { return nil, err } return &diskBusyMonitor{ major: major, minor: minor, lastAt: time.Now(), lastIOMs: ioMs, lastWrites: writes, }, nil } func (m *diskBusyMonitor) Sample(now time.Time) (diskBusySample, error) { ioMs, writes, err := readDiskStats(m.major, m.minor) if err != nil { return diskBusySample{}, err } elapsedMs := now.Sub(m.lastAt).Milliseconds() if elapsedMs <= 0 { return diskBusySample{}, nil } deltaIOMs := int64(ioMs - m.lastIOMs) deltaWrites := int64(writes - m.lastWrites) m.lastAt = now m.lastIOMs = ioMs m.lastWrites = writes util := float64(deltaIOMs) * 100 / float64(elapsedMs) if util < 0 { util = 0 } if util > 100 { util = 100 } var await time.Duration if deltaWrites > 0 { await = time.Duration(deltaIOMs/int64(deltaWrites)) * time.Millisecond } return diskBusySample{ UtilPct: util, Await: await, }, nil } func resolveWholeDiskMajorMinor(targetPath string) (int, int, error) { var st unix.Stat_t if err := unix.Stat(targetPath, &st); err != nil { return 0, 0, fmt.Errorf("stat target %q: %w", targetPath, err) } if st.Mode&unix.S_IFMT != unix.S_IFBLK { return 0, 0, fmt.Errorf("target %q is not a block device", targetPath) } major := int(unix.Major(uint64(st.Rdev))) minor := int(unix.Minor(uint64(st.Rdev))) sysfsPath := fmt.Sprintf("/sys/dev/block/%d:%d", major, minor) resolved, err := filepath.EvalSymlinks(sysfsPath) if err != nil { return major, minor, nil } // Partition path usually looks like .../block/sda/sda3 // Parent whole disk is .../block/sda parent := filepath.Dir(resolved) devName := filepath.Base(parent) ueventPath := filepath.Join(parent, "dev") data, err := os.ReadFile(ueventPath) if err != nil { return major, minor, nil } parts := strings.Split(strings.TrimSpace(string(data)), ":") if len(parts) != 2 { return major, minor, nil } parentMajor, err1 := strconv.Atoi(parts[0]) parentMinor, err2 := strconv.Atoi(parts[1]) if err1 != nil || err2 != nil || devName == "" { return major, minor, nil } return parentMajor, parentMinor, nil } func readDiskStats(major, minor int) (ioMs uint64, writesCompleted uint64, err error) { f, err := os.Open("/proc/diskstats") if err != nil { return 0, 0, fmt.Errorf("open /proc/diskstats: %w", err) } defer f.Close() sc := bufio.NewScanner(f) for sc.Scan() { line := strings.Fields(sc.Text()) if len(line) < 14 { continue } maj, err := strconv.Atoi(line[0]) if err != nil { continue } min, err := strconv.Atoi(line[1]) if err != nil { continue } if maj != major || min != minor { continue } // writes completed successfully: field 5, index 4 writesCompleted, err = strconv.ParseUint(line[4], 10, 64) if err != nil { return 0, 0, fmt.Errorf("parse writes completed for %d:%d: %w", major, minor, err) } // time spent doing I/Os (ms): field 13, index 12 ioMs, err = strconv.ParseUint(line[12], 10, 64) if err != nil { return 0, 0, fmt.Errorf("parse io_ms for %d:%d: %w", major, minor, err) } return ioMs, writesCompleted, nil } if err := sc.Err(); err != nil { return 0, 0, fmt.Errorf("scan /proc/diskstats: %w", err) } return 0, 0, fmt.Errorf("device %d:%d not found in /proc/diskstats", major, minor) }