Refactor PID control to simplify the code and modifications

Each PID uses its internal dt, so may be various contexts with different rate.
PID has max dt, so no need to reset explicitly.
This commit is contained in:
Oleg Kalachev
2025-10-20 22:54:18 +03:00
parent d06eb2a1aa
commit ca595edce5
2 changed files with 30 additions and 38 deletions

View File

@@ -9,7 +9,6 @@
#include "lpf.h" #include "lpf.h"
#include "util.h" #include "util.h"
#define ARMED_THRUST 0.1 // thrust to indicate armed state
#define PITCHRATE_P 0.05 #define PITCHRATE_P 0.05
#define PITCHRATE_I 0.2 #define PITCHRATE_I 0.2
#define PITCHRATE_D 0.001 #define PITCHRATE_D 0.001
@@ -100,12 +99,7 @@ void interpretControls() {
} }
void controlAttitude() { void controlAttitude() {
if (!armed || attitudeTarget.invalid()) { // skip attitude control if (!armed || attitudeTarget.invalid()) return; // skip attitude control
rollPID.reset();
pitchPID.reset();
yawPID.reset();
return;
}
const Vector up(0, 0, 1); const Vector up(0, 0, 1);
Vector upActual = Quaternion::rotateVector(up, attitude); Vector upActual = Quaternion::rotateVector(up, attitude);
@@ -113,28 +107,23 @@ void controlAttitude() {
Vector error = Vector::rotationVectorBetween(upTarget, upActual); Vector error = Vector::rotationVectorBetween(upTarget, upActual);
ratesTarget.x = rollPID.update(error.x, dt) + ratesExtra.x; ratesTarget.x = rollPID.update(error.x) + ratesExtra.x;
ratesTarget.y = pitchPID.update(error.y, dt) + ratesExtra.y; ratesTarget.y = pitchPID.update(error.y) + ratesExtra.y;
float yawError = wrapAngle(attitudeTarget.getYaw() - attitude.getYaw()); float yawError = wrapAngle(attitudeTarget.getYaw() - attitude.getYaw());
ratesTarget.z = yawPID.update(yawError, dt) + ratesExtra.z; ratesTarget.z = yawPID.update(yawError) + ratesExtra.z;
} }
void controlRates() { void controlRates() {
if (!armed || ratesTarget.invalid()) { // skip rates control if (!armed || ratesTarget.invalid()) return; // skip rates control
rollRatePID.reset();
pitchRatePID.reset();
yawRatePID.reset();
return;
}
Vector error = ratesTarget - rates; Vector error = ratesTarget - rates;
// Calculate desired torque, where 0 - no torque, 1 - maximum possible torque // Calculate desired torque, where 0 - no torque, 1 - maximum possible torque
torqueTarget.x = rollRatePID.update(error.x, dt); torqueTarget.x = rollRatePID.update(error.x);
torqueTarget.y = pitchRatePID.update(error.y, dt); torqueTarget.y = pitchRatePID.update(error.y);
torqueTarget.z = yawRatePID.update(error.z, dt); torqueTarget.z = yawRatePID.update(error.z);
} }
void controlTorque() { void controlTorque() {
@@ -145,12 +134,11 @@ void controlTorque() {
return; return;
} }
if (thrustTarget < 0.05) { if (thrustTarget < 0.1) {
// minimal thrust to indicate armed state motors[0] = 0.1; // idle thrust
motors[0] = ARMED_THRUST; motors[1] = 0.1;
motors[1] = ARMED_THRUST; motors[2] = 0.1;
motors[2] = ARMED_THRUST; motors[3] = 0.1;
motors[3] = ARMED_THRUST;
return; return;
} }

View File

@@ -9,40 +9,44 @@
class PID { class PID {
public: public:
float p = 0; float p, i, d;
float i = 0; float windup;
float d = 0; float dtMax;
float windup = 0;
float derivative = 0; float derivative = 0;
float integral = 0; float integral = 0;
LowPassFilter<float> lpf; // low pass filter for derivative term LowPassFilter<float> lpf; // low pass filter for derivative term
PID(float p, float i, float d, float windup = 0, float dAlpha = 1) : p(p), i(i), d(d), windup(windup), lpf(dAlpha) {}; PID(float p, float i, float d, float windup = 0, float dAlpha = 1, float dtMax = 0.1) :
p(p), i(i), d(d), windup(windup), lpf(dAlpha), dtMax(dtMax) {}
float update(float error, float dt) { float update(float error) {
integral += error * dt; float dt = t - prevTime;
if (isfinite(prevError) && dt > 0) { if (dt > 0 && dt < dtMax) {
// calculate derivative if both dt and prevError are valid integral += error * dt;
derivative = (error - prevError) / dt; derivative = lpf.update((error - prevError) / dt); // compute derivative and apply low-pass filter
} else {
// apply low pass filter to derivative integral = 0;
derivative = lpf.update(derivative); derivative = 0;
} }
prevError = error; prevError = error;
prevTime = t;
return p * error + constrain(i * integral, -windup, windup) + d * derivative; // PID return p * error + constrain(i * integral, -windup, windup) + d * derivative; // PID
} }
void reset() { void reset() {
prevError = NAN; prevError = NAN;
prevTime = NAN;
integral = 0; integral = 0;
derivative = 0; derivative = 0;
lpf.reset();
} }
private: private:
float prevError = NAN; float prevError = NAN;
float prevTime = NAN;
}; };